vlm: optimize tensor transport (#6003)
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,7 @@ import torch
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.mm_utils import TransportProxyTensor
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.utils import load_audio, load_image, load_video, logger
|
||||
|
||||
@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
|
||||
class BaseMultimodalProcessor(ABC):
|
||||
models = []
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
def __init__(
|
||||
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
|
||||
):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.arch = hf_config.architectures[0]
|
||||
self.server_args = server_args
|
||||
self.transport_mode = transport_mode
|
||||
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
if "pixel_values" in result and isinstance(
|
||||
result["pixel_values"], torch.Tensor
|
||||
):
|
||||
result["pixel_values"] = result["pixel_values"].to("cpu")
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
) -> List[MultimodalDataItem]:
|
||||
"""Create mm_items directly from processor output."""
|
||||
items: dict[Modality, MultimodalDataItem] = {}
|
||||
|
||||
for attr_name, value in data_dict.items():
|
||||
if attr_name == "input_ids":
|
||||
continue
|
||||
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
|
||||
mm_token_id=mm_token_id,
|
||||
)
|
||||
|
||||
# post-process
|
||||
for item in all_collected_items:
|
||||
# replace the feature tensor with a proxy
|
||||
if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
|
||||
item.feature = TransportProxyTensor(
|
||||
transport_mode=self.transport_mode, data=item.feature
|
||||
)
|
||||
elif (
|
||||
isinstance(item.precomputed_embeddings, torch.Tensor)
|
||||
and item.precomputed_embeddings.is_cuda
|
||||
):
|
||||
item.precomputed_embeddings = TransportProxyTensor(
|
||||
transport_mode=self.transport_mode, data=item.precomputed_embeddings
|
||||
)
|
||||
|
||||
return all_collected_items, input_ids, ret
|
||||
|
||||
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
models = [CLIPModel]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||
_processor
|
||||
)
|
||||
|
||||
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
models = [DeepseekVL2ForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<image>", image_token_id=self._processor.image_token_id
|
||||
).build(_processor)
|
||||
|
||||
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
|
||||
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
models = [Gemma3ForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
|
||||
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
|
||||
models = [Gemma3nForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
|
||||
|
||||
@@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
models = [InternVLChatModel]
|
||||
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
super().__init__(hf_config, server_args, _image_processor)
|
||||
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
|
||||
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MultiModalityCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token,
|
||||
|
||||
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
|
||||
class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
models = [KimiVLForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<|media_pad|>",
|
||||
# TODO: could we convert in MultimodalSpecialTokens?
|
||||
|
||||
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
LlavaMistralForCausalLM,
|
||||
]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
|
||||
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
|
||||
)
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
assert hasattr(hf_config, "vision_config")
|
||||
assert hasattr(hf_config, "text_config")
|
||||
self.vision_config = hf_config.vision_config
|
||||
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
|
||||
|
||||
if vision_type := getattr(self.vision_config, "model_type"):
|
||||
self.inner = self._get_sgl_processor_cls(vision_type)(
|
||||
hf_config, server_args, _processor
|
||||
hf_config, server_args, _processor, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [MiniCPMV, MiniCPMO]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
|
||||
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
self.im_start_id = getattr(tokenizer, "im_start_id", None)
|
||||
self.im_end_id = getattr(tokenizer, "im_end_id", None)
|
||||
self.im_token_id = getattr(tokenizer, "unk_id", None)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
|
||||
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MllamaForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self._processor.image_token,
|
||||
image_token_id=self._processor.image_token_id,
|
||||
|
||||
@@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
models = [Llama4ForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.vision_config = hf_config.vision_config
|
||||
self.text_config = hf_config.text_config
|
||||
self.boi_token_index = hf_config.boi_token_index
|
||||
|
||||
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
|
||||
class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [Phi4MMForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
self.processor = Phi4MMProcessorAdapter(_processor)
|
||||
super().__init__(hf_config, server_args, self.processor)
|
||||
super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
|
||||
|
||||
# the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
|
||||
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
|
||||
|
||||
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
|
||||
return ncols, nrows
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.IM_TOKEN_ID = getattr(
|
||||
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
||||
)
|
||||
|
||||
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [Qwen2AudioForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
self.AUDIO_TOKEN_REGEX = re.compile(
|
||||
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
||||
|
||||
@@ -201,8 +201,8 @@ async def preprocess_video(
|
||||
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
|
||||
@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
hf_config: PretrainedConfig,
|
||||
server_args: ServerArgs,
|
||||
_processor: VILAProcessor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self._processor.tokenizer.image_token,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
|
||||
Reference in New Issue
Block a user