[VLM] Support chunk prefill for VLM (#6355)
Co-authored-by: yizhang2077 <1109276519@qq.com>
@@ -116,6 +116,10 @@ class ModelConfig:
         self.is_audio_model = enable_multimodal and is_audio_model(
             self.hf_config.architectures
         )
+        self.is_multimodal_chunked_prefill_supported = (
+            enable_multimodal
+            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -574,6 +578,21 @@ def is_encoder_decoder_model(model_architectures: List[str]):
     return "MllamaForConditionalGeneration" in model_architectures


+def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
+    """Check if chunked prefill is supported for a MultiModal model."""
+    unsupported = [
+        "Grok1VForCausalLM",
+        "Grok1AForCausalLM",
+        "LlavaLlamaForCausalLM",
+        "MllamaForConditionalGeneration",
+        "CLIPModel",
+    ]
+    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
+        return False
+    else:
+        return True
+
+
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
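For context, a minimal sketch of how this gate behaves (the module path is assumed to be the file this hunk touches, python/sglang/srt/configs/model_config.py; the first architecture name is illustrative):

    from sglang.srt.configs.model_config import is_multimodal_chunked_prefill_supported

    # Architectures absent from the denylist keep chunked prefill enabled.
    assert is_multimodal_chunked_prefill_supported(["Qwen2VLForConditionalGeneration"])
    # Denylisted architectures force chunked prefill off (see the ModelRunner hunk below).
    assert not is_multimodal_chunked_prefill_supported(["MllamaForConditionalGeneration"])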
@@ -16,10 +16,15 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalInputs,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.utils import logger

-logger = logging.getLogger(__name__)
+# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
+# to ensure consistent logging behavior across the codebase. This prevents issues with log
+# propagation that can cause some log messages (like 'server is fired up') to not appear
+# in the console when multimodal support is enabled.


 class MultiModalityDataPaddingPattern:
@@ -189,26 +194,137 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
         return output_ids_tensor.tolist()


+embedding_cache = None
+
+
+def init_embedding_cache(max_size: int):
+    global embedding_cache
+    embedding_cache = MultiModalCache(max_size)
+
+
+def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
+    hash_list = [item.hash for item in embedding_items]
+    return hash(tuple(hash_list))
+
+
+def get_embedding_chunk(
+    embedding: torch.Tensor,
+    extend_prefix_len: int,
+    extend_seq_len: int,
+    items_offset: List[Tuple[int, int]],
+) -> Tuple[torch.Tensor, int, int]:
+    """
+    Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.
+
+    Args:
+        embedding: The full embedding tensor to extract a chunk from
+        extend_prefix_len: The starting position (prefix length) for extraction
+        extend_seq_len: The number of tokens to extract
+        items_offset: List of [start, end] offset ranges for multimodal items in the input sequence
+
+    Returns:
+        A tuple containing:
+        - The extracted embedding chunk as a tensor
+        - The start index used for extraction
+        - The end index used for extraction
+
+    Note:
+        If there's no overlap between the requested range and the offset ranges,
+        an empty tensor is returned with zeros for start and end indices.
+    """
+    start_index, end_index = 0, 0
+    extend_start_index = extend_prefix_len
+    extend_end_index = extend_prefix_len + extend_seq_len - 1
+
+    for start, end in items_offset:
+        if extend_start_index >= start and extend_start_index <= end:
+            start_index += extend_start_index - start
+        elif extend_start_index > end:
+            start_index += end - start + 1
+
+        if extend_end_index >= start and extend_end_index <= end:
+            end_index += extend_end_index - start + 1
+        elif extend_end_index > end:
+            end_index += end - start + 1
+    # some models embedding is 3-dim, reshape it to 2-dim
+    embedding = embedding.reshape(-1, embedding.shape[-1])
+    embedding_chunk = embedding[start_index:end_index]
+    return embedding_chunk, start_index, end_index
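A worked example of the index arithmetic above (the numbers are hypothetical):

    import torch

    # One image whose placeholder tokens occupy positions 4..9 of the request;
    # this prefill chunk covers tokens 6..11 (prefix 6, extend length 6).
    emb = torch.randn(6, 4096)  # 6 cached image-token embeddings
    chunk, start, end = get_embedding_chunk(
        embedding=emb, extend_prefix_len=6, extend_seq_len=6, items_offset=[(4, 9)]
    )
    # start = 6 - 4 = 2, end = 9 - 4 + 1 = 6: this chunk consumes rows 2:6.
    assert (start, end) == (2, 6) and chunk.shape[0] == 4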
 def get_embedding_and_mask(
     data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
     embedding_items: List[MultimodalDataItem],
     placeholder_tensor: torch.Tensor,
     input_ids: torch.Tensor,
-):
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Get the multimodal embedding and its mask from input_ids
+    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
+
+    Args:
+        data_embedding_func: Function that generates embeddings for multimodal items
+        embedding_items: List of multimodal items to embed
+        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
+        input_ids: The input token IDs tensor
+        items_size: Cumulative sizes of multimodal items per request
+        prefix_length: Prefix lengths for each request
+        extend_length: Sequence lengths for each request
+        items_offset_list: List of offset ranges for multimodal items in each request
+
+    Returns:
+        A tuple containing:
+        - The generated embeddings tensor
+        - A boolean mask tensor indicating where these embeddings should be placed
+
+    Raises:
+        AssertionError: If the number of multimodal tokens in input_ids doesn't match
+            the number of tokens in the generated embeddings
     """
-    # 1. Get the embedding
-    embedding = data_embedding_func(embedding_items)
+    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
+    embedding_list = []
+    for i in range(len(items_size) - 1):
+        if items_size[i] == items_size[i + 1]:
+            continue
+        embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
+        items_offset = items_offset_list[i]
+        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
+        # if all items has been prefixed, we do not need to calculate embedding
+        if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
+            continue
+        embedding_per_req = embedding_cache.get(embedding_items_hash)
+        if embedding_per_req is None:
+            embedding_per_req = data_embedding_func(embedding_items_per_req)
+            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
+                print_warning_once(
+                    "Multimodal embedding cache is full. Consider increasing the "
+                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+                )
+
+        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+            embedding=embedding_per_req,
+            extend_prefix_len=prefix_length[i],
+            extend_seq_len=extend_length[i],
+            items_offset=items_offset,
+        )
+        # remove this item from cache if chunk reaches to the end
+        embedding_per_req_length = (
+            embedding_per_req.shape[0]
+            if embedding_per_req.dim() == 2
+            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
+        )
+        if end_index == embedding_per_req_length:
+            embedding_cache.free(embedding_items_hash)
+        embedding_list.append(embedding_per_req_chunk)
+    if len(embedding_list) == 0:
+        return None, None
+    embedding = torch.concat(embedding_list, dim=0)
     # 2. Check the embedding
+    if embedding.dim() == 2:
+        num_mm_tokens_in_embedding = embedding.shape[0]
+    else:
+        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
+
     # the mask of multimodal tokens from input_ids
-    num_mm_tokens_in_embedding = embedding.shape[0]
     special_multimodal_mask = torch.isin(
         input_ids,
         placeholder_tensor,
@@ -222,9 +338,6 @@ def get_embedding_and_mask(
                 "tokens from multimodal embeddings."
             )
     if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
-        # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
-        # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
-        # extend_start_loc and extend_seq_lens
         chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
         if chunked_prefill_size != -1:
             logger.warning(
@@ -245,7 +358,9 @@


 def embed_mm_inputs(
-    mm_inputs: MultimodalInputs,
+    mm_inputs_list: List[MultimodalInputs],
+    extend_prefix_lens: List[int],
+    extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
     image_data_embedding_func: Callable[
@@ -257,125 +372,133 @@ def embed_mm_inputs(
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
-    Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
+    Embed multimodal inputs and integrate them with text token embeddings.

-    Args:
-        placeholder_tokens: denoting the token of multimodal data in input_ids.
-            If none, the pad_values of multimodal items are used
+    Args:
+        mm_inputs_list: List of multimodal inputs to process
+        extend_prefix_lens: Prefix lengths for each request
+        extend_seq_lens: Sequence lengths for each request
+        input_ids: Input token IDs tensor
+        input_embedding: Embedding layer for text tokens
+        image_data_embedding_func: Function to embed image data
+        audio_data_embedding_func: Function to embed audio data
+        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

-    Returns:
-        final embedding: Optional[torch.Tensor]
+    Returns:
+        Combined embedding tensor with multimodal content integrated
     """

-    if mm_inputs is None:
+    if mm_inputs_list is None:
         return None

     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
     # See `pad_input_ids` for more detail
+    item_flatten_list = []
+    for mm_inputs in mm_inputs_list:
+        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

-    # if placeholder_tokens is specified
-    if placeholder_tokens is not None:
-        placeholder_token_ids = flatten_nested_list(
-            [placeholder_token for placeholder_token in placeholder_tokens.values()]
-        )
-    else:
-        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
-
-    assert isinstance(placeholder_token_ids[0], int)
-
-    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
-
-    placeholder_masks = torch.isin(input_ids, placeholder_tensor)
-
-    appearing_pad_values = torch.unique(
-        input_ids[placeholder_masks], return_counts=False
-    )
-
-    if appearing_pad_values.numel() == 0:
-        # all been prefixed
-        inputs_embeds = input_embedding(input_ids)
-    else:
-        appearing_items = [
-            item
-            for item in mm_inputs.mm_items
-            if item.pad_value is not None and item.pad_value in appearing_pad_values
-        ]
-
-        using_all_items = False
-        if len(appearing_items) == 0:
-            # This happens mostly when arg placeholder_token_ids is passed
-            logger.warning(
-                "No multimodal data item's pad value exist in placeholder ids. Using all items"
-            )
-            using_all_items = True
-            appearing_items = mm_inputs.mm_items
-
-        embeddings, masks = [], []
+    embeddings, masks = [], []

-        # 2. Get multimodal embedding separately
-        # TODO: make this more generic
-        # Try get image embedding if any
-        if (
-            any(True for item in appearing_items if item.is_image())
-            and image_data_embedding_func
-        ):
-            items = [item for item in appearing_items if item.is_image()]
-            embedding, mask = get_embedding_and_mask(
-                data_embedding_func=image_data_embedding_func,
-                embedding_items=items,
-                placeholder_tensor=(
-                    # use the specified modality token to identify the location to embed
-                    placeholder_tokens[Modality.IMAGE]
-                    if using_all_items
-                    else torch.tensor(
-                        [item.pad_value for item in items],
-                        device=input_ids.device,
-                    )
-                ),
-                input_ids=input_ids,
-            )
-            embeddings += [embedding]
-            masks += [mask]
+    # 2. Get multimodal embedding separately
+    # TODO: make this more generic
+    # Try get image embedding if any
+    if (
+        any(True for item in item_flatten_list if item.is_image())
+        and image_data_embedding_func
+    ):
+        items = [item for item in item_flatten_list if item.is_image()]
+        placeholder_tensor = torch.tensor(
+            [item.pad_value for item in items],
+            device=input_ids.device,
+        )
+        # calculate per request items length offset
+        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+        items_offsets = []
+        for i, mm_inputs in enumerate(mm_inputs_list):
+            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
+            items_size[i + 1] = len(image_items)
+            items_offsets.append(
+                flatten_nested_list(
+                    [
+                        item.image_offsets
+                        for item in mm_inputs.mm_items
+                        if item.is_image()
+                    ]
+                )
+            )
+        items_size = torch.cumsum(items_size, dim=0).tolist()
+
+        embedding, mask = get_embedding_and_mask(
+            data_embedding_func=image_data_embedding_func,
+            embedding_items=items,
+            placeholder_tensor=placeholder_tensor,
+            input_ids=input_ids,
+            items_size=items_size,
+            prefix_length=extend_prefix_lens,
+            extend_length=extend_seq_lens,
+            items_offset_list=items_offsets,
+        )
+        embeddings += [embedding]
+        masks += [mask]

-        # Try get audio embedding if any
-        if (
-            any(True for item in appearing_items if item.is_audio())
-            and audio_data_embedding_func
-        ):
-            items = [item for item in appearing_items if item.is_audio()]
-            embedding, mask = get_embedding_and_mask(
-                data_embedding_func=audio_data_embedding_func,
-                embedding_items=items,
-                placeholder_tensor=(
-                    placeholder_tokens[Modality.AUDIO]
-                    if using_all_items
-                    else torch.tensor(
-                        [item.pad_value for item in items],
-                        device=input_ids.device,
-                    )
-                ),
-                input_ids=input_ids,
-            )
-            embeddings += [embedding]
-            masks += [mask]
+    # Try get audio embedding if any
+    if (
+        any(True for item in item_flatten_list if item.is_audio())
+        and audio_data_embedding_func
+    ):
+        items = [item for item in item_flatten_list if item.is_audio()]
+        placeholder_tensor = torch.tensor(
+            [item.pad_value for item in items],
+            device=input_ids.device,
+        )
+        items_offsets = []
+        # calculate per request items length offset
+        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+        for i, mm_inputs in enumerate(mm_inputs_list):
+            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
+            items_size[i + 1] = len(audio_items)
+            items_offsets.append(
+                flatten_nested_list(
+                    [
+                        item.audio_offsets
+                        for item in mm_inputs.mm_items
+                        if item.is_audio()
+                    ]
+                )
+            )
+        items_size = torch.cumsum(items_size, dim=0)
+
+        embedding, mask = get_embedding_and_mask(
+            data_embedding_func=audio_data_embedding_func,
+            embedding_items=items,
+            placeholder_tensor=placeholder_tensor,
+            input_ids=input_ids,
+            items_size=items_size,
+            prefix_length=extend_prefix_lens,
+            extend_length=extend_seq_lens,
+            items_offset_list=items_offsets,
+        )
+        embeddings += [embedding]
+        masks += [mask]

-        # 3. Get input embeddings
-        vocab_size = input_embedding.num_embeddings
-        # Important: clamp after getting original multimodal regions
-        # Clamp input ids. This is because the input_ids for the multimodal tokens are
-        # filled with the hash values of the multimodal for the prefix matching in the radix attention.
-        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-        input_ids.clamp_(min=0, max=vocab_size - 1)
-        inputs_embeds = input_embedding(input_ids)
+    # 3. Get input embeddings
+    vocab_size = input_embedding.num_embeddings
+    # Important: clamp after getting original multimodal regions
+    # Clamp input ids. This is because the input_ids for the multimodal tokens are
+    # filled with the hash values of the multimodal for the prefix matching in the radix attention.
+    # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+    input_ids.clamp_(min=0, max=vocab_size - 1)
+    inputs_embeds = input_embedding(input_ids)

-        # 4. Scatter embeddings into input embedding
-        for embedding, mask in zip(embeddings, masks):
-            mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-            inputs_embeds = inputs_embeds.masked_scatter(
-                mask,
-                embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-            )
+    # 4. scatter embeddings into input embedding
+    for embedding, mask in zip(embeddings, masks):
+        if embedding is None or mask is None:
+            continue
+        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            mask,
+            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+        )
     return inputs_embeds
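To make the per-request bookkeeping concrete, a small sketch (hypothetical batch: request 0 carries one image, request 1 is text-only):

    import torch

    mm_items_per_req = [1, 0]
    items_size = torch.zeros(len(mm_items_per_req) + 1, dtype=int)
    for i, n in enumerate(mm_items_per_req):
        items_size[i + 1] = n
    items_size = torch.cumsum(items_size, dim=0).tolist()  # -> [0, 1, 1]
    # embedding_items[items_size[i] : items_size[i + 1]] selects request i's items;
    # the empty slice for request 1 is skipped by the `continue` in get_embedding_and_mask.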
@@ -393,16 +516,19 @@ def general_mm_embed_routine(
     **kwargs,
 ) -> torch.Tensor:
     """
-    A general wrapper function to get final input embeds from multimodal models with a language model as causal model
+    Process multimodal inputs and forward through language model.

-    Args:
-        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
-        image_data_embedding_func : the function returning the image embedding
-        audio_data_embedding_func : the function returning the image embedding
-
-    Returns:
-        forwarded hidden states
+    Args:
+        input_ids: Input token IDs tensor
+        forward_batch: Batch information for model forward pass
+        language_model: Base language model to use
+        image_data_embedding_func: Function to embed image data
+        audio_data_embedding_func: Function to embed audio data
+        placeholder_tokens: Token IDs for multimodal placeholders
+        **kwargs: Additional arguments passed to language model
+
+    Returns:
+        Hidden states from language model forward pass
     """
     assert hasattr(language_model, "get_input_embeddings")
     embed_tokens = language_model.get_input_embeddings()
@@ -410,9 +536,23 @@ def general_mm_embed_routine(
         not forward_batch.forward_mode.is_decode()
         and forward_batch.contains_mm_inputs()
     ):
-        mm_input = forward_batch.merge_mm_inputs()
+        mm_inputs_list = [
+            mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
+        ]
+        extend_prefix_lens = [
+            prefix_len
+            for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
+            if forward_batch.mm_inputs[i] is not None
+        ]
+        extend_seq_lens = [
+            seq_len
+            for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
+            if forward_batch.mm_inputs[i] is not None
+        ]
         inputs_embeds = embed_mm_inputs(
-            mm_inputs=mm_input,
+            mm_inputs_list=mm_inputs_list,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
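A minimal illustration of the filtering above (names and numbers hypothetical; `img_inputs` stands in for a real MultimodalInputs object):

    mm_inputs = [None, img_inputs]  # forward_batch.mm_inputs, one entry per request
    extend_prefix_lens_cpu = [0, 512]

    mm_inputs_list = [m for m in mm_inputs if m is not None]
    extend_prefix_lens = [
        p for i, p in enumerate(extend_prefix_lens_cpu) if mm_inputs[i] is not None
    ]
    # -> mm_inputs_list == [img_inputs], extend_prefix_lens == [512];
    # the lists stay index-aligned because they share the same filter.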
@@ -5,7 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -343,6 +343,33 @@ class BaseMultimodalProcessor(ABC):
         out.normalize()
         return out

+    @staticmethod
+    def get_mm_items_offset(
+        input_ids: torch.Tensor, mm_token_id: int
+    ) -> List[Tuple[int, int]]:
+        """
+        Get a set of range for mm_items from input_ids
+        Example:
+            input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
+            mm_token_id = 3
+            return result = [(2,4),(6,7)]
+        """
+        mask = input_ids == mm_token_id
+
+        start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
+        end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
+
+        return list(zip(start_positions.tolist(), end_positions.tolist()))
+
+    @staticmethod
+    def get_mm_items_offset_by_pair(
+        input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
+    ) -> List[Tuple[int, int]]:
+        indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
+        indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
+
+        return list(zip(indices_start.tolist(), indices_end.tolist()))
+
     def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
         """Returns true if all images are preprocessed, false if all are not, and error otherwise."""
         if not mm_inputs:
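The docstring example, reproduced with real tensors (a standalone sketch, not part of the diff):

    import torch

    ids = torch.tensor([1, 2, 3, 3, 3, 4, 3, 3])
    mask = ids == 3
    # A run starts where the mask is set but its left neighbor is not, and ends
    # where the mask is set but its right neighbor is not.
    starts = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
    ends = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
    assert list(zip(starts.tolist(), ends.tolist())) == [(2, 4), (6, 7)]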
@@ -70,8 +70,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

         items = []
+        input_ids = res["input_ids"]
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids, mm_token_id=self._processor.image_token_id
+        )
         item = MultimodalDataItem(
             pixel_values=res["images"],
+            image_offsets=image_offsets,
             modality=Modality.IMAGE,
             image_emb_mask=images_seq_mask,
             image_spatial_crop=batched_images_spatial_crop,
@@ -80,6 +85,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):

         return {
             "mm_items": items,
-            "input_ids": res["input_ids"].tolist(),
+            "input_ids": input_ids.tolist(),
             "im_token_id": self._processor.image_token_id,
         }
@@ -61,6 +61,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         )

         items = []
+        input_ids = ret["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=self.hf_config.image_token_index,
+        )
         for i, image in enumerate(base_output.images):
             if images_are_preprocessed:
                 pixel_values = image.pixel_values
@@ -73,12 +78,13 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
                 pixel_values=pixel_values,
                 precomputed_features=precomputed_features,
                 modality=Modality.IMAGE,
+                image_offsets=image_offsets[i],
             )
             items += [item]

         return {
             "mm_items": items,
-            "input_ids": ret["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
@@ -209,7 +209,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             return None

         pixel_values = torch.cat(pixel_values, dim=0)
-        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]

         for idx, num_patches in enumerate(num_patches_list):
             image_tokens = (
@@ -220,10 +219,21 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             input_text = input_text.replace("<image>", image_tokens, 1)

         tokenizer = self._processor
+        input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=self.img_context_token_id,
+        )
+        items = [
+            MultimodalDataItem(
+                pixel_values=pixel_values,
+                modality=Modality.IMAGE,
+                image_offsets=image_offsets,
+            )
+        ]

         return {
-            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
-            .flatten()
-            .tolist(),
+            "input_ids": input_ids.tolist(),
             "mm_items": items,
             "im_start_id": self.img_start_token_id,
             "im_end_id": self.img_end_token_id,
@@ -45,15 +45,21 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
             prompt=base_out.input_text,
             images=images,
         )

+        input_ids = res["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids, mm_token_id=processor.image_id
+        )
         return {
             "mm_items": [
                 MultimodalDataItem(
                     pixel_values=res["pixel_values"],
                     image_emb_mask=res["images_emb_mask"],
+                    image_offsets=image_offsets,
                     modality=Modality.IMAGE,
                 )
             ],
-            "input_ids": res["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "im_start_id": processor.image_start_id,
             "im_end_id": processor.image_end_id,
             "im_token_id": processor.image_id,
@@ -1,10 +1,5 @@
 import asyncio
-import math
-from typing import List, Union
-
-import torch
-from PIL import Image

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -57,13 +52,19 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             input_text=base_output.input_text,
             images=base_output.images,
         )
+        input_ids = ret["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=self.im_token_id,
+        )
         return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "mm_items": [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
                     image_grid_thws=ret["image_grid_hws"],
                     modality=Modality.IMAGE,
+                    image_offsets=image_offsets,
                 )
             ],
             "im_token_id": self.im_token_id,
@@ -1,5 +1,4 @@
 import asyncio
-import importlib
 from typing import List, Optional, Union

 import numpy as np
@@ -69,6 +69,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_start_id = tokenizer.audio_start_id
         audio_end_id = tokenizer.audio_end_id

+        im_start_id = tokenizer.im_start_id
+        im_end_id = tokenizer.im_end_id
         im_token_id = tokenizer.unk_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
@@ -104,9 +106,20 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         pixel_values = pixel_values_flat

         items = []
+        input_ids = res["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset_by_pair(
+            input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
+        )
+        slice_offsets = self.get_mm_items_offset_by_pair(
+            input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
+        )
+        image_offsets.extend(slice_offsets)
+        image_offsets = sorted(image_offsets)
+
         if len(pixel_values) != 0:
             item = MultimodalDataItem(
                 pixel_values=pixel_values,
+                image_offsets=image_offsets,
                 tgt_size=tgt_sizes_flat,
                 modality=Modality.IMAGE,
             )
@@ -117,21 +130,30 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             and res["audio_features"] is not None
             and len(res["audio_features"]) != 0
         ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
             item = MultimodalDataItem(
                 audio_features=[res["audio_features"]],
                 audio_feature_lens=res["audio_feature_lens"],
+                audio_offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
             items += [item]

         return {
             "mm_items": items,
-            "input_ids": res["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "audio_start_id": audio_start_id,
             "audio_end_id": audio_end_id,
             "im_token_id": im_token_id,
-            "im_start_id": tokenizer.im_start_id,
-            "im_end_id": tokenizer.im_end_id,
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
             "slice_start_id": slice_start_id,
             "slice_end_id": slice_end_id,
         }
@@ -135,11 +135,17 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         processor_output["im_end_id"] = self.eoi_token_index
         processor_output["im_token_id"] = self.image_token_index

+        image_offsets = self.get_mm_items_offset(
+            input_ids=torch.tensor(processor_output["input_ids"]),
+            mm_token_id=self.image_token_index,
+        )
+
         # Add metadata for image processing
         processor_output["mm_items"] = [
             MultimodalDataItem(
                 pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
+                image_offsets=image_offsets,
             )
         ]
@@ -1,9 +1,7 @@
 import asyncio
-import math
-from typing import List, Optional, Union
+from typing import List, Union

 import numpy as np
 from transformers import PretrainedConfig
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens as _get_pixtral_hf_num_image_tokens,
 )
@@ -12,11 +10,7 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
 )
-from sglang.srt.managers.schedule_batch import (
-    Modality,
-    MultimodalDataItem,
-    MultimodalInputs,
-)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.pixtral import PixtralVisionModel
@@ -108,15 +102,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
         )

         if "pixel_values" in processor_output:
+            input_ids = processor_output["input_ids"].view(-1)
+            image_offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id=self.image_token_id,
+            )
             mm_items = [
                 MultimodalDataItem(
                     pixel_values=processor_output["pixel_values"],
                     image_sizes=processor_output["image_sizes"],
                     modality=Modality.IMAGE,
+                    image_offsets=image_offsets,
                 )
             ]

-            input_ids = processor_output["input_ids"].view(-1).tolist()
+            input_ids = input_ids.tolist()
             processor_output.update(
                 input_ids=input_ids,
                 mm_items=mm_items,
@@ -135,6 +135,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             images=None if images_are_preprocessed else base_output.images,
         )
         input_ids = ret["input_ids"].flatten().tolist()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
+        )
         image_grid_thw = None
         video_grid_thw = None  # TODO
         items = []
@@ -175,6 +178,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 image_grid_thws=image_grid_thw,
                 video_grid_thws=video_grid_thw,
                 precomputed_features=precomputed_features,
+                image_offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]
@@ -197,6 +197,7 @@ class MultimodalDataItem:

     audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None

     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@@ -1097,7 +1098,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

-        assert len(self.out_cache_loc) == self.extend_num_tokens
+        assert (
+            len(self.out_cache_loc) == self.extend_num_tokens
+        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"

     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
@@ -102,6 +102,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
@@ -2282,6 +2283,10 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

+    embedding_cache_size = 100
+    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
+        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
+    init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
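The cache defaults to 100 MB per scheduler process; a sketch of overriding it (the value is illustrative):

    import os

    # Must be set before the scheduler process starts, e.g. in the launch environment.
    os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "512"  # 512 MB for cached encoder outputs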
python/sglang/srt/mem_cache/multimodal_cache.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+from typing import Dict
+
+import torch
+
+
+class MultiModalCache:
+    """MultiModalCache is used to store vlm encoder results"""
+
+    def __init__(
+        self,
+        max_size: int,
+    ):
+        self.max_size = max_size
+        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.current_size = 0
+
+    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
+        if mm_hash in self.mm_cache:
+            return True
+        data_size = self._get_tensor_size(embedding)
+        if self.current_size + data_size > self.max_size:
+            return False
+        self.mm_cache[mm_hash] = embedding
+        self.current_size += data_size
+        return True
+
+    def get(self, mm_hash: int) -> torch.Tensor:
+        return self.mm_cache.get(mm_hash)
+
+    def free(self, mm_hash: int) -> bool:
+        if mm_hash not in self.mm_cache:
+            return False
+        old_embedding = self.mm_cache.pop(mm_hash)
+        self.current_size -= self._get_tensor_size(old_embedding)
+        return True
+
+    def clear(self):
+        self.mm_cache.clear()
+        self.current_size = 0
+
+    def _get_tensor_size(self, embedding: torch.Tensor):
+        return embedding.element_size() * embedding.numel()
+
+    def __len__(self):
+        return len(self.mm_cache)
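Typical lifecycle, as used by mm_utils above (a sketch; the hash value is a stand-in for get_embedding_hash(...)):

    cache = MultiModalCache(max_size=64 * 1024 * 1024)  # 64 MB byte budget
    emb = torch.randn(1024, 4096, dtype=torch.float16)  # ~8 MB encoder output
    h = 12345                                           # per-request item hash
    assert cache.put(h, emb)     # False would mean the budget is exhausted
    assert cache.get(h) is emb   # hit on a later prefill chunk
    cache.free(h)                # evicted once the last chunk is consumed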
@@ -166,6 +166,9 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.is_multimodal_chunked_prefill_supported = (
+            model_config.is_multimodal_chunked_prefill_supported
+        )
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -389,12 +392,15 @@ class ModelRunner:
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
             logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
-            )
-            server_args.chunked_prefill_size = -1
-            logger.info(
-                "Automatically turn off --chunked-prefill-size for multimodal model."
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
             )
+            if not self.is_multimodal_chunked_prefill_supported:
+                server_args.chunked_prefill_size = -1
+                logger.info(
+                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
+                    f"{self.model_config.hf_config.model_type}"
+                )

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
@@ -1826,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
         **kwargs: Any,
     ) -> torch.Tensor:

-        mm_input = forward_batch.merge_mm_inputs()
-        placeholder_token_ids = (
-            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
-            if forward_batch.contains_mm_inputs()
-            else []
-        )
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_tokens={
-                Modality.IMAGE: placeholder_token_ids,
-                Modality.AUDIO: placeholder_token_ids,
-            },
             positions=positions,
         )
         return hidden_states