vlm: enable radix cache for qwen-vl models (#5349)
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -49,7 +49,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
@@ -488,11 +488,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = mm_inputs.im_start_id
|
||||
im_end_id: int = mm_inputs.im_end_id
|
||||
|
||||
media_token_pairs = [(im_start_id, im_end_id)]
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
im_token_id: int = mm_inputs.im_token_id
|
||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user