vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Mick
2025-04-24 12:35:05 +09:00
committed by GitHub
parent 7d0edf3cae
commit c998d04b46
26 changed files with 429 additions and 331 deletions

View File

@@ -10,12 +10,13 @@ import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once
from sglang.srt.utils import flatten_nested_list, print_warning_once
logger = logging.getLogger(__name__)
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
return padded_ids
class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be represented as repetitions of a single token
    e.g. <image><image>....<image>, or <audio><audio>...<audio>
    """

    def __init__(self, token_ids: List[int]) -> None:
        # Token ids that mark multimodal placeholder positions in `input_ids`.
        self.token_ids = token_ids

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.

        The i-th region is filled with the i-th pad value. If there are more
        regions than pad values, the pad values are reused cyclically; if there
        are fewer regions than pad values, the surplus pad values are ignored.
        A region whose pad value is None is left untouched.

        Args:
            input_ids: token ids of the prompt.
            mm_inputs: multimodal inputs; only `mm_items[*].pad_value` is read.

        Returns:
            The padded token ids. When there is nothing to pad (no items, or no
            token matching `self.token_ids`), the original `input_ids` list
            object is returned unchanged.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        if not pad_values:
            # No multimodal items, return original input_ids
            return input_ids
        if not input_ids:
            return []

        input_ids_tensor = torch.tensor(input_ids)
        device = input_ids_tensor.device
        token_ids_tensor = torch.tensor(self.token_ids, device=device)
        mask = torch.isin(input_ids_tensor, token_ids_tensor)

        if not mask.any():
            # No tokens match token_ids, return original input_ids
            return input_ids

        # Pad the mask with False on both sides so that regions touching either
        # end of the sequence still produce a start and an end transition.
        edge = torch.tensor([False], device=device)
        padded_mask = torch.cat((edge, mask, edge))

        # Indices where the mask value changes. Transitions strictly alternate
        # False->True (region start) and True->False (exclusive region end).
        boundaries = torch.where(padded_mask[1:] != padded_mask[:-1])[0].tolist()
        starts = boundaries[::2]
        ends = boundaries[1::2]

        num_regions = len(starts)
        num_pad_values = len(pad_values)
        if num_regions != num_pad_values:
            # Kept at debug level: the mismatch is reconciled below rather than
            # treated as an error, and this runs for every padded request.
            logger.debug(
                "Found %d multimodal token regions but %d pad values; "
                "pad values are reused cyclically or truncated.",
                num_regions,
                num_pad_values,
            )

        # `input_ids_tensor` was freshly built from the list above, so it can be
        # modified in place without affecting the caller's `input_ids`.
        for i, (start_idx, end_idx) in enumerate(zip(starts, ends)):
            # The modulo reuses pad values cyclically when regions outnumber them.
            pad_value = pad_values[i % num_pad_values]
            if pad_value is not None:
                input_ids_tensor[start_idx:end_idx] = pad_value
            else:
                logger.warning(f"Skipping region {i} due to None pad_value.")

        return input_ids_tensor.tolist()
def get_embedding_and_mask(
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
).unsqueeze(-1)
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
logger.warning(
f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
placeholder_token_ids: List[int] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
"""
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
Args:
placeholder_token_ids: denoting the token of multimodal data in input_ids.
placeholder_tokens: denoting the token of multimodal data in input_ids.
If none, the pad_values of multimodal items are used
Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
# we assume that multimodal data are represented with its pad_values in input_ids
placeholder_token_ids = placeholder_token_ids or [
item.pad_value for item in mm_inputs.mm_items
]
# See `pad_input_ids` for more detail
# if placeholder_tokens is specified
if placeholder_tokens is not None:
placeholder_token_ids = flatten_nested_list(
[placeholder_token for placeholder_token in placeholder_tokens.values()]
)
else:
placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
assert isinstance(placeholder_token_ids[0], int)
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
@@ -233,7 +290,7 @@ def embed_mm_inputs(
using_all_items = False
if len(appearing_items) == 0:
# This happens mostly when arg placeholder_tokens is passed
logger.warning_once(
logger.warning(
"No multimodal data item's pad value exist in placeholder ids. Using all items"
)
using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tensor
# use the specified modality token to identify the location to embed
placeholder_tokens[Modality.IMAGE]
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tensor
placeholder_tokens[Modality.AUDIO]
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
# 4. Scatter embeddings into input embedding
for embedding, mask in zip(embeddings, masks):
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
placeholder_token_ids: List[int] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
audio_data_embedding_func : the function returning the audio embedding
Returns:
inputs_embedding
forwarded hidden states
"""
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
input_embedding=embed_tokens,
image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
)
# once used, mm_inputs is useless
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None
else: