vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Mick
2025-04-24 12:35:05 +09:00
committed by GitHub
parent 7d0edf3cae
commit c998d04b46
26 changed files with 429 additions and 331 deletions

View File

@@ -12,7 +12,7 @@ from sglang.srt.configs.deepseekvl2 import (
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -249,8 +249,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
weights_loader(param, loaded_weight)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
helper = MultiModalityDataPaddingPatternImageTokens(
image_token_id=image_inputs.im_token_id
helper = MultiModalityDataPaddingPatternMultimodalTokens(
[image_inputs.im_token_id]
)
return helper.pad_input_tokens(input_ids, image_inputs)

View File

@@ -43,6 +43,7 @@ from sglang.srt.managers.mm_utils import (
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
@@ -1834,7 +1835,10 @@ class MiniCPMO(MiniCPMBaseModel):
language_model=self.llm,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens={
Modality.IMAGE: placeholder_token_ids,
Modality.AUDIO: placeholder_token_ids,
},
positions=positions,
)
return hidden_states

View File

@@ -10,7 +10,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -53,7 +53,7 @@ class Llama4ForConditionalGeneration(nn.Module):
# Get all special token IDs
im_token_id: int = mm_inputs.im_token_id
pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(

View File

@@ -49,7 +49,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -488,11 +488,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = mm_inputs.im_start_id
im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
im_token_id: int = mm_inputs.im_token_id
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:

View File

@@ -42,7 +42,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -490,15 +490,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = mm_inputs.im_start_id
im_end_id: int = mm_inputs.im_end_id
im_token_id: int = mm_inputs.im_token_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: