vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Mick
2025-04-24 12:35:05 +09:00
committed by GitHub
parent 7d0edf3cae
commit c998d04b46
26 changed files with 429 additions and 331 deletions

View File

@@ -8,6 +8,7 @@ from typing import List, Optional
import numpy as np
import PIL
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
@abstractmethod
async def process_mm_data_async(
self, image_data, input_text, max_req_input_len, **kwargs
self,
image_data,
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
pass
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
from decord import VideoReader, cpu
# Before processing inputs
if not image_data or len(image_data) == 0:
return []
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if image_data is None:
image_data = []
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = (
self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
prompt = self._processor.tokenizer.decode(prompt)
else:
prompt = prompt
assert isinstance(prompt, str)
if return_text:
import re

View File

@@ -16,6 +16,7 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from typing import List, Union
import torch
@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
self.IMAGE_TOKEN = "<image>"
async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs
):
if not image_data:
return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids,
input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,

View File

@@ -1,7 +1,5 @@
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
input_text,
request_obj,
max_req_input_len,
*args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
prompt=input_ids,
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,

View File

@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
input_text,
request_obj,
max_req_input_len,
**kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
processor = self._processor
base_out = self.load_mm_data(
prompt=input_ids,
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=processor.image_token

View File

@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
audio_data = request_obj.audio_data
if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data = [audio_data]
base_output = self.load_mm_data(
prompt=input_ids,
prompt=input_text,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,

View File

@@ -5,6 +5,7 @@ from typing import List, Union
import torch
from PIL import Image
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
self.video_token_id = hf_config.video_token_id
self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
prompt,
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
prompt=prompt,
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def resize_image_async(image):
return resize_image(image)
resize_tasks = [resize_image_async(image) for image in base_output.images]
resized_images = await asyncio.gather(*resize_tasks)
if base_output.images:
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=resized_images,
images=base_output.images,
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"mm_items": [
items = []
input_ids = ret["input_ids"].flatten().tolist()
if "pixel_values" in ret:
items += [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=image_grid_thws,
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
# TODO
video_grid_thws=None,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
modality=Modality.IMAGE,
)
],
]
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.image_token_id,
video_token_id=self.video_token_id,
vision_start_token_id=self.vision_start_token_id,
model_type=self.hf_config.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=torch.tensor(input_ids).unsqueeze(0),
image_grid_thw=ret.get("image_grid_thw", None),
video_grid_thw=ret.get("video_grid_thw", None),
second_per_grid_ts=ret.get("second_per_grid_ts", None),
)
mrope_positions = mrope_positions.squeeze(1)
return {
"input_ids": input_ids,
"mm_items": items,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
}