[VLM] Support chunk prefill for VLM (#6355)
Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
@@ -116,6 +116,10 @@ class ModelConfig:
|
|||||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||||
self.hf_config.architectures
|
self.hf_config.architectures
|
||||||
)
|
)
|
||||||
|
self.is_multimodal_chunked_prefill_supported = (
|
||||||
|
enable_multimodal
|
||||||
|
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||||
|
)
|
||||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
|
||||||
@@ -574,6 +578,21 @@ def is_encoder_decoder_model(model_architectures: List[str]):
|
|||||||
return "MllamaForConditionalGeneration" in model_architectures
|
return "MllamaForConditionalGeneration" in model_architectures
|
||||||
|
|
||||||
|
|
||||||
|
def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model.

    Args:
        model_architectures: Architecture names to check.

    Returns:
        False if any of the given architectures is on the unsupported list
        below, True otherwise (an empty list therefore yields True).
    """
    # Architectures for which chunked prefill is reported as unsupported.
    unsupported = {
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    }
    return not any(arch in unsupported for arch in model_architectures)
|
||||||
|
|
||||||
|
|
||||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||||
if scale <= 1:
|
if scale <= 1:
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|||||||
@@ -16,10 +16,15 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import flatten_nested_list, print_warning_once
|
from sglang.srt.utils import flatten_nested_list, print_warning_once
|
||||||
|
from sglang.utils import logger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
|
||||||
|
# to ensure consistent logging behavior across the codebase. This prevents issues with log
|
||||||
|
# propagation that can cause some log messages (like 'server is fired up') to not appear
|
||||||
|
# in the console when multimodal support is enabled.
|
||||||
|
|
||||||
|
|
||||||
class MultiModalityDataPaddingPattern:
|
class MultiModalityDataPaddingPattern:
|
||||||
@@ -189,26 +194,137 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
|||||||
return output_ids_tensor.tolist()
|
return output_ids_tensor.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
embedding_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_embedding_cache(max_size: int) -> None:
    """Create the module-level multimodal embedding cache.

    Rebinds the global ``embedding_cache`` to a freshly constructed
    ``MultiModalCache``; whatever object was bound before is dropped.

    Args:
        max_size: Capacity handed to ``MultiModalCache``. The unit is defined
            by that class (presumably bytes -- confirm against its definition).
    """
    global embedding_cache
    embedding_cache = MultiModalCache(max_size)
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
    """Combine the per-item hashes of a group of items into one integer key.

    The ``hash`` attribute of every item is collected into a tuple and that
    tuple is hashed, so the result depends on the order of the items.

    Args:
        embedding_items: Items whose ``hash`` attributes are combined.

    Returns:
        The hash of the tuple of item hashes.
    """
    return hash(tuple(item.hash for item in embedding_items))
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_chunk(
    embedding: torch.Tensor,
    extend_prefix_len: int,
    extend_seq_len: int,
    items_offset: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, int, int]:
    """
    Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.

    The requested token window is ``[extend_prefix_len, extend_prefix_len + extend_seq_len - 1]``
    (inclusive). Each offset range contributes the number of its positions that lie before the
    window start to the start index, and the number of its positions that lie at or before the
    window end to the end index.

    Args:
        embedding: The full embedding tensor to extract a chunk from. It is flattened to 2-dim
            (rows x last dim) before slicing; rows are assumed to correspond, in order, to the
            positions covered by ``items_offset`` -- confirm against the caller.
        extend_prefix_len: The starting position (prefix length) for extraction
        extend_seq_len: The number of tokens to extract
        items_offset: List of inclusive (start, end) offset ranges for multimodal items in the
            input sequence

    Returns:
        A tuple containing:
        - The extracted embedding chunk as a 2-dim tensor
        - The start row index used for extraction
        - The end row index (exclusive) used for extraction

    Note:
        If the requested window does not overlap any offset range, the returned chunk is empty
        and the start and end indices are equal. They are both 0 only when the window lies
        before every range; otherwise they equal the number of multimodal positions that
        precede the window.
    """
    start_index, end_index = 0, 0
    extend_start_index = extend_prefix_len
    extend_end_index = extend_prefix_len + extend_seq_len - 1

    for start, end in items_offset:
        # Ranges are inclusive on both ends.
        item_len = end - start + 1

        if start <= extend_start_index <= end:
            # Window starts inside this item: skip the part already covered.
            start_index += extend_start_index - start
        elif extend_start_index > end:
            # Item lies entirely before the window start.
            start_index += item_len

        if start <= extend_end_index <= end:
            # Window ends inside this item: include positions up to the end.
            end_index += extend_end_index - start + 1
        elif extend_end_index > end:
            # Item lies entirely at or before the window end.
            end_index += item_len

    # some models embedding is 3-dim, reshape it to 2-dim
    embedding = embedding.reshape(-1, embedding.shape[-1])
    embedding_chunk = embedding[start_index:end_index]
    return embedding_chunk, start_index, end_index
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_and_mask(
|
def get_embedding_and_mask(
|
||||||
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
|
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
|
||||||
embedding_items: List[MultimodalDataItem],
|
embedding_items: List[MultimodalDataItem],
|
||||||
placeholder_tensor: torch.Tensor,
|
placeholder_tensor: torch.Tensor,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
):
|
items_size: List[int],
|
||||||
|
prefix_length: List[int],
|
||||||
|
extend_length: List[int],
|
||||||
|
items_offset_list: List[List[Tuple[int, int]]],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Get the multimodal embedding and its mask from input_ids
|
Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_embedding_func: Function that generates embeddings for multimodal items
|
||||||
|
embedding_items: List of multimodal items to embed
|
||||||
|
placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
|
||||||
|
input_ids: The input token IDs tensor
|
||||||
|
items_size: Cumulative sizes of multimodal items per request
|
||||||
|
prefix_length: Prefix lengths for each request
|
||||||
|
extend_length: Sequence lengths for each request
|
||||||
|
items_offset_list: List of offset ranges for multimodal items in each request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing:
|
||||||
|
- The generated embeddings tensor
|
||||||
|
- A boolean mask tensor indicating where these embeddings should be placed
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the number of multimodal tokens in input_ids doesn't match
|
||||||
|
the number of tokens in the generated embeddings
|
||||||
"""
|
"""
|
||||||
# 1. Get the embedding
|
# 1. Get the embedding
|
||||||
embedding = data_embedding_func(embedding_items)
|
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
|
||||||
|
embedding_list = []
|
||||||
|
for i in range(len(items_size) - 1):
|
||||||
|
if items_size[i] == items_size[i + 1]:
|
||||||
|
continue
|
||||||
|
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
|
||||||
|
items_offset = items_offset_list[i]
|
||||||
|
embedding_items_hash = get_embedding_hash(embedding_items_per_req)
|
||||||
|
# if all items has been prefixed, we do not need to calculate embedding
|
||||||
|
if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
|
||||||
|
continue
|
||||||
|
embedding_per_req = embedding_cache.get(embedding_items_hash)
|
||||||
|
if embedding_per_req is None:
|
||||||
|
embedding_per_req = data_embedding_func(embedding_items_per_req)
|
||||||
|
if not embedding_cache.put(embedding_items_hash, embedding_per_req):
|
||||||
|
print_warning_once(
|
||||||
|
"Multimodal embedding cache is full. Consider increasing the "
|
||||||
|
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
|
||||||
|
embedding=embedding_per_req,
|
||||||
|
extend_prefix_len=prefix_length[i],
|
||||||
|
extend_seq_len=extend_length[i],
|
||||||
|
items_offset=items_offset,
|
||||||
|
)
|
||||||
|
# remove this item from cache if chunk reaches to the end
|
||||||
|
embedding_per_req_length = (
|
||||||
|
embedding_per_req.shape[0]
|
||||||
|
if embedding_per_req.dim() == 2
|
||||||
|
else embedding_per_req.shape[0] * embedding_per_req.shape[1]
|
||||||
|
)
|
||||||
|
if end_index == embedding_per_req_length:
|
||||||
|
embedding_cache.free(embedding_items_hash)
|
||||||
|
embedding_list.append(embedding_per_req_chunk)
|
||||||
|
if len(embedding_list) == 0:
|
||||||
|
return None, None
|
||||||
|
embedding = torch.concat(embedding_list, dim=0)
|
||||||
# 2. Check the embedding
|
# 2. Check the embedding
|
||||||
if embedding.dim() == 2:
|
num_mm_tokens_in_embedding = embedding.shape[0]
|
||||||
num_mm_tokens_in_embedding = embedding.shape[0]
|
|
||||||
else:
|
|
||||||
num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
|
|
||||||
|
|
||||||
# the mask of multimodal tokens from input_ids
|
|
||||||
special_multimodal_mask = torch.isin(
|
special_multimodal_mask = torch.isin(
|
||||||
input_ids,
|
input_ids,
|
||||||
placeholder_tensor,
|
placeholder_tensor,
|
||||||
@@ -222,9 +338,6 @@ def get_embedding_and_mask(
|
|||||||
"tokens from multimodal embeddings."
|
"tokens from multimodal embeddings."
|
||||||
)
|
)
|
||||||
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
|
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
|
||||||
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
|
|
||||||
# a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
|
|
||||||
# extend_start_loc and extend_seq_lens
|
|
||||||
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
||||||
if chunked_prefill_size != -1:
|
if chunked_prefill_size != -1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -245,7 +358,9 @@ def get_embedding_and_mask(
|
|||||||
|
|
||||||
|
|
||||||
def embed_mm_inputs(
|
def embed_mm_inputs(
|
||||||
mm_inputs: MultimodalInputs,
|
mm_inputs_list: List[MultimodalInputs],
|
||||||
|
extend_prefix_lens: List[int],
|
||||||
|
extend_seq_lens: List[int],
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
input_embedding: nn.Embedding,
|
input_embedding: nn.Embedding,
|
||||||
image_data_embedding_func: Callable[
|
image_data_embedding_func: Callable[
|
||||||
@@ -257,125 +372,133 @@ def embed_mm_inputs(
|
|||||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
|
Embed multimodal inputs and integrate them with text token embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
placeholder_tokens: denoting the token of multimodal data in input_ids.
|
mm_inputs_list: List of multimodal inputs to process
|
||||||
If none, the pad_values of multimodal items are used
|
extend_prefix_lens: Prefix lengths for each request
|
||||||
|
extend_seq_lens: Sequence lengths for each request
|
||||||
|
input_ids: Input token IDs tensor
|
||||||
|
input_embedding: Embedding layer for text tokens
|
||||||
|
image_data_embedding_func: Function to embed image data
|
||||||
|
audio_data_embedding_func: Function to embed audio data
|
||||||
|
placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
final embedding: Optional[torch.Tensor]
|
Combined embedding tensor with multimodal content integrated
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if mm_inputs is None:
|
if mm_inputs_list is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
|
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
|
||||||
# we assume that multimodal data are represented with its pad_values in input_ids
|
# we assume that multimodal data are represented with its pad_values in input_ids
|
||||||
# See `pad_input_ids` for more detail
|
item_flatten_list = []
|
||||||
|
for mm_inputs in mm_inputs_list:
|
||||||
|
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
||||||
|
|
||||||
# if placeholder_tokens is specified
|
embeddings, masks = [], []
|
||||||
if placeholder_tokens is not None:
|
|
||||||
placeholder_token_ids = flatten_nested_list(
|
# 2. Get multimodal embedding separately
|
||||||
[placeholder_token for placeholder_token in placeholder_tokens.values()]
|
# TODO: make this more generic
|
||||||
|
# Try get image embedding if any
|
||||||
|
if (
|
||||||
|
any(True for item in item_flatten_list if item.is_image())
|
||||||
|
and image_data_embedding_func
|
||||||
|
):
|
||||||
|
items = [item for item in item_flatten_list if item.is_image()]
|
||||||
|
placeholder_tensor = torch.tensor(
|
||||||
|
[item.pad_value for item in items],
|
||||||
|
device=input_ids.device,
|
||||||
)
|
)
|
||||||
else:
|
# calculate per request items length offset
|
||||||
placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
|
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
|
||||||
|
items_offsets = []
|
||||||
assert isinstance(placeholder_token_ids[0], int)
|
for i, mm_inputs in enumerate(mm_inputs_list):
|
||||||
|
image_items = [item for item in mm_inputs.mm_items if item.is_image()]
|
||||||
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
|
items_size[i + 1] = len(image_items)
|
||||||
|
items_offsets.append(
|
||||||
placeholder_masks = torch.isin(input_ids, placeholder_tensor)
|
flatten_nested_list(
|
||||||
|
[
|
||||||
appearing_pad_values = torch.unique(
|
item.image_offsets
|
||||||
input_ids[placeholder_masks], return_counts=False
|
for item in mm_inputs.mm_items
|
||||||
)
|
if item.is_image()
|
||||||
|
]
|
||||||
if appearing_pad_values.numel() == 0:
|
)
|
||||||
# all been prefixed
|
|
||||||
inputs_embeds = input_embedding(input_ids)
|
|
||||||
else:
|
|
||||||
appearing_items = [
|
|
||||||
item
|
|
||||||
for item in mm_inputs.mm_items
|
|
||||||
if item.pad_value is not None and item.pad_value in appearing_pad_values
|
|
||||||
]
|
|
||||||
|
|
||||||
using_all_items = False
|
|
||||||
if len(appearing_items) == 0:
|
|
||||||
# This happens mostly when arg placeholder_token_ids is passed
|
|
||||||
logger.warning(
|
|
||||||
"No multimodal data item's pad value exist in placeholder ids. Using all items"
|
|
||||||
)
|
)
|
||||||
using_all_items = True
|
items_size = torch.cumsum(items_size, dim=0).tolist()
|
||||||
appearing_items = mm_inputs.mm_items
|
|
||||||
|
|
||||||
embeddings, masks = [], []
|
embedding, mask = get_embedding_and_mask(
|
||||||
|
data_embedding_func=image_data_embedding_func,
|
||||||
|
embedding_items=items,
|
||||||
|
placeholder_tensor=placeholder_tensor,
|
||||||
|
input_ids=input_ids,
|
||||||
|
items_size=items_size,
|
||||||
|
prefix_length=extend_prefix_lens,
|
||||||
|
extend_length=extend_seq_lens,
|
||||||
|
items_offset_list=items_offsets,
|
||||||
|
)
|
||||||
|
embeddings += [embedding]
|
||||||
|
masks += [mask]
|
||||||
|
|
||||||
# 2. Get multimodal embedding separately
|
# Try get audio embedding if any
|
||||||
# TODO: make this more generic
|
if (
|
||||||
# Try get image embedding if any
|
any(True for item in item_flatten_list if item.is_audio())
|
||||||
if (
|
and audio_data_embedding_func
|
||||||
any(True for item in appearing_items if item.is_image())
|
):
|
||||||
and image_data_embedding_func
|
items = [item for item in item_flatten_list if item.is_audio()]
|
||||||
):
|
placeholder_tensor = torch.tensor(
|
||||||
items = [item for item in appearing_items if item.is_image()]
|
[item.pad_value for item in items],
|
||||||
embedding, mask = get_embedding_and_mask(
|
device=input_ids.device,
|
||||||
data_embedding_func=image_data_embedding_func,
|
)
|
||||||
embedding_items=items,
|
items_offsets = []
|
||||||
placeholder_tensor=(
|
# calculate per request items length offset
|
||||||
# use the specified modality token to identify the location to embed
|
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
|
||||||
placeholder_tokens[Modality.IMAGE]
|
for i, mm_inputs in enumerate(mm_inputs_list):
|
||||||
if using_all_items
|
audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
|
||||||
else torch.tensor(
|
items_size[i + 1] = len(audio_items)
|
||||||
[item.pad_value for item in items],
|
items_offsets.append(
|
||||||
device=input_ids.device,
|
flatten_nested_list(
|
||||||
)
|
[
|
||||||
),
|
item.audio_offsets
|
||||||
input_ids=input_ids,
|
for item in mm_inputs.mm_items
|
||||||
|
if item.is_audio()
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
embeddings += [embedding]
|
items_size = torch.cumsum(items_size, dim=0)
|
||||||
masks += [mask]
|
|
||||||
|
|
||||||
# Try get audio embedding if any
|
embedding, mask = get_embedding_and_mask(
|
||||||
if (
|
data_embedding_func=audio_data_embedding_func,
|
||||||
any(True for item in appearing_items if item.is_audio())
|
embedding_items=items,
|
||||||
and audio_data_embedding_func
|
placeholder_tensor=placeholder_tensor,
|
||||||
):
|
input_ids=input_ids,
|
||||||
items = [item for item in appearing_items if item.is_audio()]
|
items_size=items_size,
|
||||||
embedding, mask = get_embedding_and_mask(
|
prefix_length=extend_prefix_lens,
|
||||||
data_embedding_func=audio_data_embedding_func,
|
extend_length=extend_seq_lens,
|
||||||
embedding_items=items,
|
items_offset_list=items_offsets,
|
||||||
placeholder_tensor=(
|
)
|
||||||
placeholder_tokens[Modality.AUDIO]
|
embeddings += [embedding]
|
||||||
if using_all_items
|
masks += [mask]
|
||||||
else torch.tensor(
|
|
||||||
[item.pad_value for item in items],
|
|
||||||
device=input_ids.device,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
input_ids=input_ids,
|
|
||||||
)
|
|
||||||
embeddings += [embedding]
|
|
||||||
masks += [mask]
|
|
||||||
|
|
||||||
# 3. Get input embeddings
|
# 3. Get input embeddings
|
||||||
vocab_size = input_embedding.num_embeddings
|
vocab_size = input_embedding.num_embeddings
|
||||||
# Important: clamp after getting original multimodal regions
|
# Important: clamp after getting original multimodal regions
|
||||||
# Clamp input ids. This is because the input_ids for the multimodal tokens are
|
# Clamp input ids. This is because the input_ids for the multimodal tokens are
|
||||||
# filled with the hash values of the multimodal for the prefix matching in the radix attention.
|
# filled with the hash values of the multimodal for the prefix matching in the radix attention.
|
||||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||||
inputs_embeds = input_embedding(input_ids)
|
inputs_embeds = input_embedding(input_ids)
|
||||||
|
|
||||||
# 4. Scatter embeddings into input embedding
|
# 4. scatter embeddings into input embedding
|
||||||
for embedding, mask in zip(embeddings, masks):
|
for embedding, mask in zip(embeddings, masks):
|
||||||
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
if embedding is None or mask is None:
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(
|
continue
|
||||||
mask,
|
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
inputs_embeds = inputs_embeds.masked_scatter(
|
||||||
)
|
mask,
|
||||||
|
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
||||||
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
|
||||||
@@ -393,16 +516,19 @@ def general_mm_embed_routine(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
A general wrapper function to get final input embeds from multimodal models with a language model as causal model
|
Process multimodal inputs and forward through language model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
|
input_ids: Input token IDs tensor
|
||||||
image_data_embedding_func : the function returning the image embedding
|
forward_batch: Batch information for model forward pass
|
||||||
audio_data_embedding_func : the function returning the image embedding
|
language_model: Base language model to use
|
||||||
|
image_data_embedding_func: Function to embed image data
|
||||||
Returns:
|
audio_data_embedding_func: Function to embed audio data
|
||||||
forwarded hidden states
|
placeholder_tokens: Token IDs for multimodal placeholders
|
||||||
|
**kwargs: Additional arguments passed to language model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hidden states from language model forward pass
|
||||||
"""
|
"""
|
||||||
assert hasattr(language_model, "get_input_embeddings")
|
assert hasattr(language_model, "get_input_embeddings")
|
||||||
embed_tokens = language_model.get_input_embeddings()
|
embed_tokens = language_model.get_input_embeddings()
|
||||||
@@ -410,9 +536,23 @@ def general_mm_embed_routine(
|
|||||||
not forward_batch.forward_mode.is_decode()
|
not forward_batch.forward_mode.is_decode()
|
||||||
and forward_batch.contains_mm_inputs()
|
and forward_batch.contains_mm_inputs()
|
||||||
):
|
):
|
||||||
mm_input = forward_batch.merge_mm_inputs()
|
mm_inputs_list = [
|
||||||
|
mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
|
||||||
|
]
|
||||||
|
extend_prefix_lens = [
|
||||||
|
prefix_len
|
||||||
|
for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
if forward_batch.mm_inputs[i] is not None
|
||||||
|
]
|
||||||
|
extend_seq_lens = [
|
||||||
|
seq_len
|
||||||
|
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
|
||||||
|
if forward_batch.mm_inputs[i] is not None
|
||||||
|
]
|
||||||
inputs_embeds = embed_mm_inputs(
|
inputs_embeds = embed_mm_inputs(
|
||||||
mm_inputs=mm_input,
|
mm_inputs_list=mm_inputs_list,
|
||||||
|
extend_prefix_lens=extend_prefix_lens,
|
||||||
|
extend_seq_lens=extend_seq_lens,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_embedding=embed_tokens,
|
input_embedding=embed_tokens,
|
||||||
image_data_embedding_func=image_data_embedding_func,
|
image_data_embedding_func=image_data_embedding_func,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import multiprocessing as mp
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -343,6 +343,33 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
out.normalize()
|
out.normalize()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@staticmethod
def get_mm_items_offset(
    input_ids: torch.Tensor, mm_token_id: int
) -> List[Tuple[int, int]]:
    """
    Get the inclusive (start, end) index range of every run of
    ``mm_token_id`` in ``input_ids``.

    Example:
        input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
        mm_token_id = 3
        return result = [(2,4),(6,7)]

    Args:
        input_ids: Token ids. The tensor is flattened, so returned positions
            index the flattened sequence.
        mm_token_id: The token id marking multimodal positions.

    Returns:
        One (start, end) tuple per contiguous run, in order of appearance.
        Runs touching the first or last position are reported too.
    """
    mask = (input_ids == mm_token_id).reshape(-1)

    # Pad with False on both sides instead of using torch.roll: roll wraps
    # around, so a run at position 0 would be hidden by a run at the last
    # position (and vice versa), producing missing or misaligned pairs.
    edge = torch.zeros(1, dtype=torch.bool, device=mask.device)
    padded = torch.cat((edge, mask, edge))
    inside = padded[1:-1]

    # A run starts where the previous token is not a match, and ends where
    # the next token is not a match.
    start_positions = (inside & ~padded[:-2]).nonzero(as_tuple=True)[0]
    end_positions = (inside & ~padded[2:]).nonzero(as_tuple=True)[0]

    return list(zip(start_positions.tolist(), end_positions.tolist()))
|
||||||
|
|
||||||
|
@staticmethod
def get_mm_items_offset_by_pair(
    input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
) -> List[Tuple[int, int]]:
    """
    Get the inclusive (start, end) index ranges enclosed by start/end
    marker tokens in ``input_ids``.

    Example:
        input_ids = [1, 10, 5, 5, 11, 2]
        mm_start_id = 10, mm_end_id = 11
        return result = [(2,3)]

    Args:
        input_ids: Token ids; positions are taken along the first dimension
            (assumes a 1-dim tensor -- confirm against callers).
        mm_start_id: Token id of the opening marker; the range begins at the
            position right after it.
        mm_end_id: Token id of the closing marker; the range ends at the
            position right before it.

    Returns:
        The i-th start marker paired with the i-th end marker. If the two
        counts differ, the unmatched extras are silently dropped by ``zip``.
    """
    # Shift by one so the markers themselves are excluded from the range.
    indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
    indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1

    return list(zip(indices_start.tolist(), indices_end.tolist()))
|
||||||
|
|
||||||
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
|
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
|
||||||
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
|
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
|
||||||
if not mm_inputs:
|
if not mm_inputs:
|
||||||
|
|||||||
@@ -70,8 +70,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids, mm_token_id=self._processor.image_token_id
|
||||||
|
)
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
pixel_values=res["images"],
|
pixel_values=res["images"],
|
||||||
|
image_offsets=image_offsets,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
image_emb_mask=images_seq_mask,
|
image_emb_mask=images_seq_mask,
|
||||||
image_spatial_crop=batched_images_spatial_crop,
|
image_spatial_crop=batched_images_spatial_crop,
|
||||||
@@ -80,6 +85,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"mm_items": items,
|
"mm_items": items,
|
||||||
"input_ids": res["input_ids"].tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"im_token_id": self._processor.image_token_id,
|
"im_token_id": self._processor.image_token_id,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
|
input_ids = ret["input_ids"].flatten()
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_token_id=self.hf_config.image_token_index,
|
||||||
|
)
|
||||||
for i, image in enumerate(base_output.images):
|
for i, image in enumerate(base_output.images):
|
||||||
if images_are_preprocessed:
|
if images_are_preprocessed:
|
||||||
pixel_values = image.pixel_values
|
pixel_values = image.pixel_values
|
||||||
@@ -73,12 +78,13 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
precomputed_features=precomputed_features,
|
precomputed_features=precomputed_features,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
image_offsets=image_offsets[i],
|
||||||
)
|
)
|
||||||
items += [item]
|
items += [item]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mm_items": items,
|
"mm_items": items,
|
||||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"im_start_id": self.IM_START_TOKEN_ID,
|
"im_start_id": self.IM_START_TOKEN_ID,
|
||||||
"im_end_id": self.IM_END_TOKEN_ID,
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,7 +209,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
pixel_values = torch.cat(pixel_values, dim=0)
|
pixel_values = torch.cat(pixel_values, dim=0)
|
||||||
items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
|
|
||||||
|
|
||||||
for idx, num_patches in enumerate(num_patches_list):
|
for idx, num_patches in enumerate(num_patches_list):
|
||||||
image_tokens = (
|
image_tokens = (
|
||||||
@@ -220,10 +219,21 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
|||||||
input_text = input_text.replace("<image>", image_tokens, 1)
|
input_text = input_text.replace("<image>", image_tokens, 1)
|
||||||
|
|
||||||
tokenizer = self._processor
|
tokenizer = self._processor
|
||||||
|
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_token_id=self.img_context_token_id,
|
||||||
|
)
|
||||||
|
items = [
|
||||||
|
MultimodalDataItem(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
image_offsets=image_offsets,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
|
"input_ids": input_ids.tolist(),
|
||||||
.flatten()
|
|
||||||
.tolist(),
|
|
||||||
"mm_items": items,
|
"mm_items": items,
|
||||||
"im_start_id": self.img_start_token_id,
|
"im_start_id": self.img_start_token_id,
|
||||||
"im_end_id": self.img_end_token_id,
|
"im_end_id": self.img_end_token_id,
|
||||||
|
|||||||
@@ -45,15 +45,21 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
|||||||
prompt=base_out.input_text,
|
prompt=base_out.input_text,
|
||||||
images=images,
|
images=images,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_ids = res["input_ids"].flatten()
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids, mm_token_id=processor.image_id
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"mm_items": [
|
"mm_items": [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=res["pixel_values"],
|
pixel_values=res["pixel_values"],
|
||||||
image_emb_mask=res["images_emb_mask"],
|
image_emb_mask=res["images_emb_mask"],
|
||||||
|
image_offsets=image_offsets,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
"input_ids": res["input_ids"].flatten().tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"im_start_id": processor.image_start_id,
|
"im_start_id": processor.image_start_id,
|
||||||
"im_end_id": processor.image_end_id,
|
"im_end_id": processor.image_end_id,
|
||||||
"im_token_id": processor.image_id,
|
"im_token_id": processor.image_id,
|
||||||
|
|||||||
@@ -1,10 +1,5 @@
|
|||||||
import asyncio
|
|
||||||
import math
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||||
)
|
)
|
||||||
@@ -57,13 +52,19 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
|||||||
input_text=base_output.input_text,
|
input_text=base_output.input_text,
|
||||||
images=base_output.images,
|
images=base_output.images,
|
||||||
)
|
)
|
||||||
|
input_ids = ret["input_ids"].flatten()
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_token_id=self.im_token_id,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"mm_items": [
|
"mm_items": [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=ret["pixel_values"],
|
pixel_values=ret["pixel_values"],
|
||||||
image_grid_thws=ret["image_grid_hws"],
|
image_grid_thws=ret["image_grid_hws"],
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
image_offsets=image_offsets,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
"im_token_id": self.im_token_id,
|
"im_token_id": self.im_token_id,
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import importlib
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
audio_start_id = tokenizer.audio_start_id
|
audio_start_id = tokenizer.audio_start_id
|
||||||
audio_end_id = tokenizer.audio_end_id
|
audio_end_id = tokenizer.audio_end_id
|
||||||
|
|
||||||
|
im_start_id = tokenizer.im_start_id
|
||||||
|
im_end_id = tokenizer.im_end_id
|
||||||
im_token_id = tokenizer.unk_id
|
im_token_id = tokenizer.unk_id
|
||||||
pixel_values = res["pixel_values"]
|
pixel_values = res["pixel_values"]
|
||||||
tgt_sizes = res["tgt_sizes"]
|
tgt_sizes = res["tgt_sizes"]
|
||||||
@@ -104,9 +106,20 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
pixel_values = pixel_values_flat
|
pixel_values = pixel_values_flat
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
|
input_ids = res["input_ids"].flatten()
|
||||||
|
image_offsets = self.get_mm_items_offset_by_pair(
|
||||||
|
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
|
||||||
|
)
|
||||||
|
slice_offsets = self.get_mm_items_offset_by_pair(
|
||||||
|
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
|
||||||
|
)
|
||||||
|
image_offsets.extend(slice_offsets)
|
||||||
|
image_offsets = sorted(image_offsets)
|
||||||
|
|
||||||
if len(pixel_values) != 0:
|
if len(pixel_values) != 0:
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
|
image_offsets=image_offsets,
|
||||||
tgt_size=tgt_sizes_flat,
|
tgt_size=tgt_sizes_flat,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
)
|
)
|
||||||
@@ -117,21 +130,30 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
and res["audio_features"] is not None
|
and res["audio_features"] is not None
|
||||||
and len(res["audio_features"]) != 0
|
and len(res["audio_features"]) != 0
|
||||||
):
|
):
|
||||||
|
if audio_start_id is not None and audio_end_id is not None:
|
||||||
|
audio_offsets = self.get_mm_items_offset_by_pair(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_start_id=audio_start_id,
|
||||||
|
mm_end_id=audio_end_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
audio_offsets = None
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
audio_features=[res["audio_features"]],
|
audio_features=[res["audio_features"]],
|
||||||
audio_feature_lens=res["audio_feature_lens"],
|
audio_feature_lens=res["audio_feature_lens"],
|
||||||
|
audio_offsets=audio_offsets,
|
||||||
modality=Modality.AUDIO,
|
modality=Modality.AUDIO,
|
||||||
)
|
)
|
||||||
items += [item]
|
items += [item]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mm_items": items,
|
"mm_items": items,
|
||||||
"input_ids": res["input_ids"].flatten().tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"audio_start_id": audio_start_id,
|
"audio_start_id": audio_start_id,
|
||||||
"audio_end_id": audio_end_id,
|
"audio_end_id": audio_end_id,
|
||||||
"im_token_id": im_token_id,
|
"im_token_id": im_token_id,
|
||||||
"im_start_id": tokenizer.im_start_id,
|
"im_start_id": im_start_id,
|
||||||
"im_end_id": tokenizer.im_end_id,
|
"im_end_id": im_end_id,
|
||||||
"slice_start_id": slice_start_id,
|
"slice_start_id": slice_start_id,
|
||||||
"slice_end_id": slice_end_id,
|
"slice_end_id": slice_end_id,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,11 +135,17 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
|||||||
processor_output["im_end_id"] = self.eoi_token_index
|
processor_output["im_end_id"] = self.eoi_token_index
|
||||||
processor_output["im_token_id"] = self.image_token_index
|
processor_output["im_token_id"] = self.image_token_index
|
||||||
|
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=torch.tensor(processor_output["input_ids"]),
|
||||||
|
mm_token_id=self.image_token_index,
|
||||||
|
)
|
||||||
|
|
||||||
# Add metadata for image processing
|
# Add metadata for image processing
|
||||||
processor_output["mm_items"] = [
|
processor_output["mm_items"] = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=processor_output["pixel_values"],
|
pixel_values=processor_output["pixel_values"],
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
image_offsets=image_offsets,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
from transformers.models.pixtral.image_processing_pixtral import (
|
from transformers.models.pixtral.image_processing_pixtral import (
|
||||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||||
)
|
)
|
||||||
@@ -12,11 +10,7 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
|
|||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
MultimodalSpecialTokens,
|
MultimodalSpecialTokens,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
Modality,
|
|
||||||
MultimodalDataItem,
|
|
||||||
MultimodalInputs,
|
|
||||||
)
|
|
||||||
from sglang.srt.models.pixtral import PixtralVisionModel
|
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||||
|
|
||||||
|
|
||||||
@@ -108,15 +102,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "pixel_values" in processor_output:
|
if "pixel_values" in processor_output:
|
||||||
|
input_ids = processor_output["input_ids"].view(-1)
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_token_id=self.image_token_id,
|
||||||
|
)
|
||||||
mm_items = [
|
mm_items = [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=processor_output["pixel_values"],
|
pixel_values=processor_output["pixel_values"],
|
||||||
image_sizes=processor_output["image_sizes"],
|
image_sizes=processor_output["image_sizes"],
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
|
image_offsets=image_offsets,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
input_ids = processor_output["input_ids"].view(-1).tolist()
|
input_ids = input_ids.tolist()
|
||||||
processor_output.update(
|
processor_output.update(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
mm_items=mm_items,
|
mm_items=mm_items,
|
||||||
|
|||||||
@@ -135,6 +135,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
images=None if images_are_preprocessed else base_output.images,
|
images=None if images_are_preprocessed else base_output.images,
|
||||||
)
|
)
|
||||||
input_ids = ret["input_ids"].flatten().tolist()
|
input_ids = ret["input_ids"].flatten().tolist()
|
||||||
|
image_offsets = self.get_mm_items_offset(
|
||||||
|
input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
|
||||||
|
)
|
||||||
image_grid_thw = None
|
image_grid_thw = None
|
||||||
video_grid_thw = None # TODO
|
video_grid_thw = None # TODO
|
||||||
items = []
|
items = []
|
||||||
@@ -175,6 +178,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
image_grid_thws=image_grid_thw,
|
image_grid_thws=image_grid_thw,
|
||||||
video_grid_thws=video_grid_thw,
|
video_grid_thws=video_grid_thw,
|
||||||
precomputed_features=precomputed_features,
|
precomputed_features=precomputed_features,
|
||||||
|
image_offsets=image_offsets,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ class MultimodalDataItem:
|
|||||||
|
|
||||||
audio_features: Union[torch.Tensor, np.ndarray] = None
|
audio_features: Union[torch.Tensor, np.ndarray] = None
|
||||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||||
|
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
||||||
|
|
||||||
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||||
|
|
||||||
@@ -1097,7 +1098,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
else:
|
else:
|
||||||
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
|
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
|
||||||
|
|
||||||
assert len(self.out_cache_loc) == self.extend_num_tokens
|
assert (
|
||||||
|
len(self.out_cache_loc) == self.extend_num_tokens
|
||||||
|
), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"
|
||||||
|
|
||||||
def prepare_for_extend(self):
|
def prepare_for_extend(self):
|
||||||
self.forward_mode = ForwardMode.EXTEND
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
UpdateWeightsFromTensorReqOutput,
|
UpdateWeightsFromTensorReqOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
@@ -2282,6 +2283,10 @@ def run_scheduler_process(
|
|||||||
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
||||||
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
||||||
|
|
||||||
|
embedding_cache_size = 100
|
||||||
|
if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
|
||||||
|
embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
|
||||||
|
init_embedding_cache(embedding_cache_size * 1024 * 1024)
|
||||||
# Create a scheduler and run the event loop
|
# Create a scheduler and run the event loop
|
||||||
try:
|
try:
|
||||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
|
||||||
|
|||||||
45
python/sglang/srt/mem_cache/multimodal_cache.py
Normal file
45
python/sglang/srt/mem_cache/multimodal_cache.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from typing import Dict, Optional

import torch
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalCache:
    """Size-bounded store for VLM encoder results, keyed by multimodal hash.

    The cache tracks the total number of bytes held by the stored tensors
    (``current_size``) and refuses new entries that would push that total
    above ``max_size``. Nothing is evicted automatically; entries leave the
    cache only through :meth:`free` or :meth:`clear`.
    """

    def __init__(
        self,
        max_size: int,
    ):
        """Create an empty cache.

        Args:
            max_size: Upper bound, in bytes, on the summed size of all
                cached tensors.
        """
        self.max_size = max_size
        self.mm_cache: Dict[int, torch.Tensor] = {}
        self.current_size = 0

    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
        """Store ``embedding`` under ``mm_hash`` if it fits.

        Returns:
            True if the hash is present in the cache after the call (either
            it was already cached, in which case the existing entry is kept
            untouched, or it was inserted now). False if inserting it would
            exceed ``max_size``; the cache is left unchanged in that case.
        """
        if mm_hash in self.mm_cache:
            return True
        data_size = self._get_tensor_size(embedding)
        if self.current_size + data_size > self.max_size:
            # No eviction policy: the caller has to cope with a rejected put.
            return False
        self.mm_cache[mm_hash] = embedding
        self.current_size += data_size
        return True

    def get(self, mm_hash: int) -> Optional[torch.Tensor]:
        """Return the tensor cached under ``mm_hash``, or None on a miss."""
        return self.mm_cache.get(mm_hash)

    def free(self, mm_hash: int) -> bool:
        """Remove the entry for ``mm_hash`` and release its byte budget.

        Returns:
            True if an entry was removed, False if the hash was not cached.
        """
        if mm_hash not in self.mm_cache:
            return False
        old_embedding = self.mm_cache.pop(mm_hash)
        self.current_size -= self._get_tensor_size(old_embedding)
        return True

    def clear(self):
        """Drop every entry and reset the tracked size to zero."""
        self.mm_cache.clear()
        self.current_size = 0

    def _get_tensor_size(self, embedding: torch.Tensor) -> int:
        """Return the size of ``embedding`` in bytes (element size * count)."""
        return embedding.element_size() * embedding.numel()

    def __len__(self):
        """Return the number of cached entries."""
        return len(self.mm_cache)
|
||||||
@@ -166,6 +166,9 @@ class ModelRunner:
|
|||||||
self.is_draft_worker = is_draft_worker
|
self.is_draft_worker = is_draft_worker
|
||||||
self.is_generation = model_config.is_generation
|
self.is_generation = model_config.is_generation
|
||||||
self.is_multimodal = model_config.is_multimodal
|
self.is_multimodal = model_config.is_multimodal
|
||||||
|
self.is_multimodal_chunked_prefill_supported = (
|
||||||
|
model_config.is_multimodal_chunked_prefill_supported
|
||||||
|
)
|
||||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
@@ -389,12 +392,15 @@ class ModelRunner:
|
|||||||
if self.is_multimodal:
|
if self.is_multimodal:
|
||||||
self.mem_fraction_static *= 0.90
|
self.mem_fraction_static *= 0.90
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
|
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
|
||||||
)
|
f"because this is a multimodal model."
|
||||||
server_args.chunked_prefill_size = -1
|
|
||||||
logger.info(
|
|
||||||
"Automatically turn off --chunked-prefill-size for multimodal model."
|
|
||||||
)
|
)
|
||||||
|
if not self.is_multimodal_chunked_prefill_supported:
|
||||||
|
server_args.chunked_prefill_size = -1
|
||||||
|
logger.info(
|
||||||
|
f"Automatically turn of --chunked-prefill-size as it is not supported for "
|
||||||
|
f"{self.model_config.hf_config.model_type}"
|
||||||
|
)
|
||||||
|
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
server_args.disable_chunked_prefix_cache = True
|
server_args.disable_chunked_prefix_cache = True
|
||||||
|
|||||||
@@ -1826,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
mm_input = forward_batch.merge_mm_inputs()
|
|
||||||
placeholder_token_ids = (
|
|
||||||
([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
|
|
||||||
if forward_batch.contains_mm_inputs()
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
hidden_states = general_mm_embed_routine(
|
hidden_states = general_mm_embed_routine(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
language_model=self.llm,
|
language_model=self.llm,
|
||||||
image_data_embedding_func=self.get_image_feature,
|
image_data_embedding_func=self.get_image_feature,
|
||||||
audio_data_embedding_func=self.get_audio_feature,
|
audio_data_embedding_func=self.get_audio_feature,
|
||||||
placeholder_tokens={
|
|
||||||
Modality.IMAGE: placeholder_token_ids,
|
|
||||||
Modality.AUDIO: placeholder_token_ids,
|
|
||||||
},
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -294,20 +294,24 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
|
|
||||||
# Add assertions to validate the video response
|
# Add assertions to validate the video response
|
||||||
assert "iPod" in video_response or "device" in video_response, video_response
|
assert (
|
||||||
|
"iPod" in video_response or "device" in video_response
|
||||||
|
), f"video_response: {video_response}, should contain 'iPod' or 'device'"
|
||||||
assert (
|
assert (
|
||||||
"man" in video_response
|
"man" in video_response
|
||||||
or "person" in video_response
|
or "person" in video_response
|
||||||
or "individual" in video_response
|
or "individual" in video_response
|
||||||
or "speaker" in video_response
|
or "speaker" in video_response
|
||||||
), video_response
|
), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response or 'speaker' in video_response"
|
||||||
assert (
|
assert (
|
||||||
"present" in video_response
|
"present" in video_response
|
||||||
or "examine" in video_response
|
or "examine" in video_response
|
||||||
or "display" in video_response
|
or "display" in video_response
|
||||||
or "hold" in video_response
|
or "hold" in video_response
|
||||||
)
|
), f"video_response: {video_response}, should contain 'present', 'examine', 'display', or 'hold'"
|
||||||
assert "black" in video_response or "dark" in video_response
|
assert (
|
||||||
|
"black" in video_response or "dark" in video_response
|
||||||
|
), f"video_response: {video_response}, should contain 'black' or 'dark'"
|
||||||
self.assertIsNotNone(video_response)
|
self.assertIsNotNone(video_response)
|
||||||
self.assertGreater(len(video_response), 0)
|
self.assertGreater(len(video_response), 0)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,10 @@ from transformers import (
|
|||||||
from sglang import Engine
|
from sglang import Engine
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
|
||||||
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
|
BaseMultimodalProcessor,
|
||||||
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
Modality,
|
Modality,
|
||||||
MultimodalDataItem,
|
MultimodalDataItem,
|
||||||
@@ -188,6 +191,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
|||||||
.eval()
|
.eval()
|
||||||
.to(cls.device)
|
.to(cls.device)
|
||||||
)
|
)
|
||||||
|
init_embedding_cache(0)
|
||||||
|
|
||||||
async def test_vlm_embedding_output(self):
|
async def test_vlm_embedding_output(self):
|
||||||
"""
|
"""
|
||||||
@@ -226,17 +230,41 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
|||||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||||
pixel_values_flat += [pixel_n]
|
pixel_values_flat += [pixel_n]
|
||||||
tgt_sizes_flat += [tgt_n]
|
tgt_sizes_flat += [tgt_n]
|
||||||
|
|
||||||
|
im_start_id, im_end_id = (
|
||||||
|
self.tokenizer.im_start_id,
|
||||||
|
self.tokenizer.im_end_id,
|
||||||
|
)
|
||||||
|
slice_start_id, slice_end_id = (
|
||||||
|
self.tokenizer.slice_start_id,
|
||||||
|
self.tokenizer.slice_end_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
|
||||||
|
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
|
||||||
|
)
|
||||||
|
slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
|
||||||
|
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
|
||||||
|
)
|
||||||
|
image_offsets.extend(slice_offsets)
|
||||||
|
image_offsets = sorted(image_offsets)
|
||||||
|
|
||||||
sglang_output = embed_mm_inputs(
|
sglang_output = embed_mm_inputs(
|
||||||
mm_inputs=MultimodalInputs(
|
mm_inputs_list=[
|
||||||
mm_items=[
|
MultimodalInputs(
|
||||||
MultimodalDataItem(
|
mm_items=[
|
||||||
pixel_values=pixel_values_flat,
|
MultimodalDataItem(
|
||||||
tgt_size=tgt_sizes_flat,
|
pixel_values=pixel_values_flat,
|
||||||
modality=Modality.IMAGE,
|
image_offsets=image_offsets,
|
||||||
pad_value=self.processor.tokenizer.unk_token_id,
|
tgt_size=tgt_sizes_flat,
|
||||||
)
|
modality=Modality.IMAGE,
|
||||||
]
|
pad_value=self.processor.tokenizer.unk_token_id,
|
||||||
),
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
extend_prefix_lens=[0],
|
||||||
|
extend_seq_lens=[input_ids.shape[0]],
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_embedding=model.get_input_embeddings(),
|
input_embedding=model.get_input_embeddings(),
|
||||||
image_data_embedding_func=model.get_image_feature,
|
image_data_embedding_func=model.get_image_feature,
|
||||||
|
|||||||
Reference in New Issue
Block a user