vlm: optimize tensor transport (#6003)

Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
Mick
2025-07-26 17:41:01 +08:00
committed by GitHub
parent 534756749a
commit 3212c2ad3f
23 changed files with 221 additions and 60 deletions

View File

@@ -12,6 +12,7 @@ import torch
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.mm_utils import TransportProxyTensor
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
class BaseMultimodalProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
def __init__(
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
self.transport_mode = transport_mode
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
return_tensors="pt",
**kwargs,
)
if "pixel_values" in result and isinstance(
result["pixel_values"], torch.Tensor
):
result["pixel_values"] = result["pixel_values"].to("cpu")
return result
@abstractmethod
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
) -> List[MultimodalDataItem]:
"""Create mm_items directly from processor output."""
items: dict[Modality, MultimodalDataItem] = {}
for attr_name, value in data_dict.items():
if attr_name == "input_ids":
continue
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
mm_token_id=mm_token_id,
)
# post-process
for item in all_collected_items:
# replace the feature tensor with a proxy
if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
item.feature = TransportProxyTensor(
transport_mode=self.transport_mode, data=item.feature
)
elif (
isinstance(item.precomputed_embeddings, torch.Tensor)
and item.precomputed_embeddings.is_cuda
):
item.precomputed_embeddings = TransportProxyTensor(
transport_mode=self.transport_mode, data=item.precomputed_embeddings
)
return all_collected_items, input_ids, ret

View File

@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class ClipImageProcessor(BaseMultimodalProcessor):
models = [CLIPModel]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)

View File

@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)

View File

@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
self.mm_tokens = MultimodalSpecialTokens(

View File

@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
models = [Gemma3nForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_START_TOKEN_ID = hf_config.boi_token_id
self.IM_END_TOKEN_ID = hf_config.eoi_token_id

View File

@@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class InternVLImageProcessor(BaseMultimodalProcessor):
models = [InternVLChatModel]
def __init__(self, hf_config, server_args, _image_processor):
super().__init__(hf_config, server_args, _image_processor)
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size

View File

@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,

View File

@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
class KimiVLImageProcessor(SGLangBaseProcessor):
models = [KimiVLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<|media_pad|>",
# TODO: could we convert in MultimodalSpecialTokens?

View File

@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
LlavaMistralForCausalLM,
]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@staticmethod
def _process_single_image_task(
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
)
def __init__(self, hf_config, server_args, _processor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
assert hasattr(hf_config, "vision_config")
assert hasattr(hf_config, "text_config")
self.vision_config = hf_config.vision_config
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
if vision_type := getattr(self.vision_config, "model_type"):
self.inner = self._get_sgl_processor_cls(vision_type)(
hf_config, server_args, _processor
hf_config, server_args, _processor, *args, **kwargs
)
else:
raise ValueError(

View File

@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
models = [MiniCPMV, MiniCPMO]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# Collect special token ids
tokenizer = self._processor.tokenizer
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
self.im_start_id = getattr(tokenizer, "im_start_id", None)
self.im_end_id = getattr(tokenizer, "im_end_id", None)
self.im_token_id = getattr(tokenizer, "unk_id", None)
self.mm_tokens = MultimodalSpecialTokens(
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",

View File

@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class MllamaImageProcessor(BaseMultimodalProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token=self._processor.image_token,
image_token_id=self._processor.image_token_id,

View File

@@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class Mllama4ImageProcessor(BaseMultimodalProcessor):
models = [Llama4ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.boi_token_index = hf_config.boi_token_index

View File

@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
models = [Phi4MMForCausalLM]
def __init__(self, hf_config, server_args, _processor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
self.processor = Phi4MMProcessorAdapter(_processor)
super().__init__(hf_config, server_args, self.processor)
super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
# the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py

View File

@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
return ncols, nrows
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_TOKEN_ID = getattr(
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
)

View File

@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
models = [Qwen2AudioForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"

View File

@@ -201,8 +201,8 @@ async def preprocess_video(
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# The regex that matches expanded image tokens.
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id

View File

@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
hf_config: PretrainedConfig,
server_args: ServerArgs,
_processor: VILAProcessor,
*args,
**kwargs,
) -> None:
super().__init__(hf_config, server_args, _processor)
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token=self._processor.tokenizer.image_token,
image_token_id=hf_config.image_token_id,