vlm: optimize tensor transport (#6003)
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -3,8 +3,9 @@ Multi-modality utils
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import pickle
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -27,6 +28,130 @@ from sglang.utils import logger
|
|||||||
# propagation that can cause some log messages (like 'server is fired up') to not appear
|
# propagation that can cause some log messages (like 'server is fired up') to not appear
|
||||||
# in the console when multimodal support is enabled.
|
# in the console when multimodal support is enabled.
|
||||||
|
|
||||||
|
# TODO(mick): nccl
|
||||||
|
# cuda_ipc: for intranode tensor sharing
|
||||||
|
TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
|
||||||
|
|
||||||
|
|
||||||
|
class TransportProxyTensor(torch.Tensor):
|
||||||
|
"""
|
||||||
|
A convenient torch.Tensor subclass that carries extra metadata and supports
|
||||||
|
efficient inter-process communications
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
data: torch.Tensor,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
fields: Optional[Dict[str, Any]] = None,
|
||||||
|
transport_mode: TensorTransportMode = "default",
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
raise TypeError(
|
||||||
|
f"Input 'data' must be a torch.Tensor, but got {type(data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
instance = data.as_subclass(cls)
|
||||||
|
|
||||||
|
instance._metadata = {
|
||||||
|
"name": name,
|
||||||
|
"fields": fields if fields is not None else {},
|
||||||
|
"transport_mode": transport_mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
"""
|
||||||
|
Called during pickling. Implements the serialization logic.
|
||||||
|
"""
|
||||||
|
# acquire all serialize metadata from _metadata
|
||||||
|
state = {
|
||||||
|
"metadata": self._metadata,
|
||||||
|
"tensor_data": None,
|
||||||
|
"ipc_extra": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
transport_mode = self._metadata.get("transport_mode", "default")
|
||||||
|
|
||||||
|
if transport_mode == "cuda_ipc" and self.is_cuda:
|
||||||
|
try:
|
||||||
|
storage = self.untyped_storage()
|
||||||
|
handle = storage._share_cuda_()
|
||||||
|
|
||||||
|
state["ipc_extra"] = {
|
||||||
|
"handle": handle,
|
||||||
|
"shape": self.shape,
|
||||||
|
"dtype": self.dtype,
|
||||||
|
"stride": self.stride(),
|
||||||
|
"device_index": self.device.index,
|
||||||
|
}
|
||||||
|
state["tensor_data"] = None
|
||||||
|
except Exception as e:
|
||||||
|
print_warning_once(
|
||||||
|
f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport."
|
||||||
|
)
|
||||||
|
state["metadata"]["transport_mode"] = "default"
|
||||||
|
state["tensor_data"] = self.as_subclass(torch.Tensor)
|
||||||
|
else:
|
||||||
|
state["metadata"]["transport_mode"] = "default"
|
||||||
|
state["tensor_data"] = self.as_subclass(torch.Tensor)
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Called during unpickling. Implements the deserialization logic.
|
||||||
|
"""
|
||||||
|
self._metadata = state["metadata"]
|
||||||
|
|
||||||
|
transport_mode = self._metadata.get("transport_mode", "default")
|
||||||
|
|
||||||
|
if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
|
||||||
|
ipc_extra = state["ipc_extra"]
|
||||||
|
handle, shape, dtype, stride, source_device_index = (
|
||||||
|
ipc_extra["handle"],
|
||||||
|
ipc_extra["shape"],
|
||||||
|
ipc_extra["dtype"],
|
||||||
|
ipc_extra["stride"],
|
||||||
|
ipc_extra["device_index"],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
target_device = torch.device(f"cuda:{source_device_index}")
|
||||||
|
with torch.cuda.device(target_device):
|
||||||
|
storage = torch.UntypedStorage._new_shared_cuda(*handle)
|
||||||
|
reconstructed_tensor = torch.empty(
|
||||||
|
0, dtype=dtype, device=target_device
|
||||||
|
).set_(storage, storage_offset=0, size=shape, stride=stride)
|
||||||
|
self.set_(reconstructed_tensor)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
elif state["tensor_data"] is not None:
|
||||||
|
self.set_(state["tensor_data"])
|
||||||
|
else:
|
||||||
|
raise pickle.UnpicklingError(
|
||||||
|
"Invalid state for TransportProxyTensor: no tensor data found."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> Optional[str]:
|
||||||
|
return self._metadata.get("name")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fields(self) -> Dict[str, Any]:
|
||||||
|
return self._metadata.get("fields", {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transport_mode(self) -> TensorTransportMode:
|
||||||
|
return self._metadata.get("transport_mode", "default")
|
||||||
|
|
||||||
|
|
||||||
class MultiModalityDataPaddingPattern:
|
class MultiModalityDataPaddingPattern:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
|
|||||||
PROCESSOR_MAPPING = {}
|
PROCESSOR_MAPPING = {}
|
||||||
|
|
||||||
|
|
||||||
class DummyMultimodalProcessor(BaseMultimodalProcessor):
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def process_mm_data_async(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_dummy_processor():
|
|
||||||
return DummyMultimodalProcessor()
|
|
||||||
|
|
||||||
|
|
||||||
def import_processors():
|
def import_processors():
|
||||||
package_name = "sglang.srt.multimodal.processors"
|
package_name = "sglang.srt.multimodal.processors"
|
||||||
package = importlib.import_module(package_name)
|
package = importlib.import_module(package_name)
|
||||||
@@ -49,11 +37,12 @@ def import_processors():
|
|||||||
|
|
||||||
|
|
||||||
def get_mm_processor(
|
def get_mm_processor(
|
||||||
hf_config, server_args: ServerArgs, processor
|
hf_config, server_args: ServerArgs, processor, transport_mode
|
||||||
) -> BaseMultimodalProcessor:
|
) -> BaseMultimodalProcessor:
|
||||||
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
|
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
|
||||||
if model_cls.__name__ in hf_config.architectures:
|
if model_cls.__name__ in hf_config.architectures:
|
||||||
return processor_cls(hf_config, server_args, processor)
|
return processor_cls(hf_config, server_args, processor, transport_mode)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No processor registered for architecture: {hf_config.architectures}.\n"
|
f"No processor registered for architecture: {hf_config.architectures}.\n"
|
||||||
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
|
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
|
||||||
|
|||||||
@@ -209,10 +209,11 @@ class MultimodalDataItem:
|
|||||||
hash: int = None
|
hash: int = None
|
||||||
pad_value: int = None
|
pad_value: int = None
|
||||||
offsets: Optional[list] = None
|
offsets: Optional[list] = None
|
||||||
|
|
||||||
# the raw features returned by processor, e.g. pixel_values or audio_features
|
# the raw features returned by processor, e.g. pixel_values or audio_features
|
||||||
feature: Union[torch.Tensor, np.ndarray] = None
|
feature: Union[torch.Tensor, np.ndarray] = None
|
||||||
|
# the precomputed embeddings, passed as final encoder embeddings
|
||||||
# the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
|
# One and only one of the feature and precomputed_embeddings will be empty
|
||||||
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
|
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||||
|
|
||||||
# Model-specific data stored in a dictionary
|
# Model-specific data stored in a dictionary
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
UpdateWeightsFromTensorReqOutput,
|
UpdateWeightsFromTensorReqOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.mm_utils import TensorTransportMode
|
||||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -166,6 +167,16 @@ class ReqState:
|
|||||||
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
||||||
|
is_cross_node = server_args.dist_init_addr
|
||||||
|
|
||||||
|
if is_cross_node:
|
||||||
|
# Fallback to default CPU transport for multi-node
|
||||||
|
return "default"
|
||||||
|
else:
|
||||||
|
return "cuda_ipc"
|
||||||
|
|
||||||
|
|
||||||
class TokenizerManager:
|
class TokenizerManager:
|
||||||
"""TokenizerManager is a process that tokenizes the text."""
|
"""TokenizerManager is a process that tokenizes the text."""
|
||||||
|
|
||||||
@@ -216,12 +227,13 @@ class TokenizerManager:
|
|||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
use_fast=not server_args.disable_fast_image_processor,
|
use_fast=not server_args.disable_fast_image_processor,
|
||||||
)
|
)
|
||||||
|
transport_mode = _determine_tensor_transport_mode(self.server_args)
|
||||||
|
|
||||||
# We want to parallelize the image pre-processing so we create an executor for it
|
# We want to parallelize the image pre-processing so we create an executor for it
|
||||||
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
|
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
|
||||||
# images even with skip_tokenizer_init=False.
|
# images even with skip_tokenizer_init=False.
|
||||||
self.mm_processor = get_mm_processor(
|
self.mm_processor = get_mm_processor(
|
||||||
self.model_config.hf_config, server_args, _processor
|
self.model_config.hf_config, server_args, _processor, transport_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import BaseImageProcessorFast
|
from transformers import BaseImageProcessorFast
|
||||||
|
|
||||||
|
from sglang.srt.managers.mm_utils import TransportProxyTensor
|
||||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.utils import load_audio, load_image, load_video, logger
|
from sglang.srt.utils import load_audio, load_image, load_video, logger
|
||||||
|
|
||||||
@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
|
|||||||
class BaseMultimodalProcessor(ABC):
|
class BaseMultimodalProcessor(ABC):
|
||||||
models = []
|
models = []
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(
|
||||||
|
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
|
||||||
|
):
|
||||||
self.hf_config = hf_config
|
self.hf_config = hf_config
|
||||||
self._processor = _processor
|
self._processor = _processor
|
||||||
self.arch = hf_config.architectures[0]
|
self.arch = hf_config.architectures[0]
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
self.transport_mode = transport_mode
|
||||||
|
|
||||||
# FIXME: not accurate, model and image specific
|
# FIXME: not accurate, model and image specific
|
||||||
self.NUM_TOKEN_PER_FRAME = 330
|
self.NUM_TOKEN_PER_FRAME = 330
|
||||||
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if "pixel_values" in result and isinstance(
|
|
||||||
result["pixel_values"], torch.Tensor
|
|
||||||
):
|
|
||||||
result["pixel_values"] = result["pixel_values"].to("cpu")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
) -> List[MultimodalDataItem]:
|
) -> List[MultimodalDataItem]:
|
||||||
"""Create mm_items directly from processor output."""
|
"""Create mm_items directly from processor output."""
|
||||||
items: dict[Modality, MultimodalDataItem] = {}
|
items: dict[Modality, MultimodalDataItem] = {}
|
||||||
|
|
||||||
for attr_name, value in data_dict.items():
|
for attr_name, value in data_dict.items():
|
||||||
if attr_name == "input_ids":
|
if attr_name == "input_ids":
|
||||||
continue
|
continue
|
||||||
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
mm_token_id=mm_token_id,
|
mm_token_id=mm_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# post-process
|
||||||
|
for item in all_collected_items:
|
||||||
|
# replace the feature tensor with a proxy
|
||||||
|
if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
|
||||||
|
item.feature = TransportProxyTensor(
|
||||||
|
transport_mode=self.transport_mode, data=item.feature
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
isinstance(item.precomputed_embeddings, torch.Tensor)
|
||||||
|
and item.precomputed_embeddings.is_cuda
|
||||||
|
):
|
||||||
|
item.precomputed_embeddings = TransportProxyTensor(
|
||||||
|
transport_mode=self.transport_mode, data=item.precomputed_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
return all_collected_items, input_ids, ret
|
return all_collected_items, input_ids, ret
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class ClipImageProcessor(BaseMultimodalProcessor):
|
class ClipImageProcessor(BaseMultimodalProcessor):
|
||||||
models = [CLIPModel]
|
models = [CLIPModel]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||||
_processor
|
_processor
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||||
models = [DeepseekVL2ForCausalLM]
|
models = [DeepseekVL2ForCausalLM]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token="<image>", image_token_id=self._processor.image_token_id
|
image_token="<image>", image_token_id=self._processor.image_token_id
|
||||||
).build(_processor)
|
).build(_processor)
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
|
|||||||
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||||
models = [Gemma3ForConditionalGeneration]
|
models = [Gemma3ForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
models = [Gemma3nForConditionalGeneration]
|
models = [Gemma3nForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
|
|
||||||
self.IM_START_TOKEN_ID = hf_config.boi_token_id
|
self.IM_START_TOKEN_ID = hf_config.boi_token_id
|
||||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
|
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class InternVLImageProcessor(BaseMultimodalProcessor):
|
class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||||
models = [InternVLChatModel]
|
models = [InternVLChatModel]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _image_processor):
|
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _image_processor)
|
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
|
||||||
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
|
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
|
||||||
patch_size = hf_config.vision_config.patch_size
|
patch_size = hf_config.vision_config.patch_size
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class JanusProImageProcessor(BaseMultimodalProcessor):
|
class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||||
models = [MultiModalityCausalLM]
|
models = [MultiModalityCausalLM]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
|
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token=_processor.image_token,
|
image_token=_processor.image_token,
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
|
|||||||
class KimiVLImageProcessor(SGLangBaseProcessor):
|
class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||||
models = [KimiVLForConditionalGeneration]
|
models = [KimiVLForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token="<|media_pad|>",
|
image_token="<|media_pad|>",
|
||||||
# TODO: could we convert in MultimodalSpecialTokens?
|
# TODO: could we convert in MultimodalSpecialTokens?
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
LlavaMistralForCausalLM,
|
LlavaMistralForCausalLM,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_single_image_task(
|
def _process_single_image_task(
|
||||||
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
|
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
assert hasattr(hf_config, "vision_config")
|
assert hasattr(hf_config, "vision_config")
|
||||||
assert hasattr(hf_config, "text_config")
|
assert hasattr(hf_config, "text_config")
|
||||||
self.vision_config = hf_config.vision_config
|
self.vision_config = hf_config.vision_config
|
||||||
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
|
|
||||||
if vision_type := getattr(self.vision_config, "model_type"):
|
if vision_type := getattr(self.vision_config, "model_type"):
|
||||||
self.inner = self._get_sgl_processor_cls(vision_type)(
|
self.inner = self._get_sgl_processor_cls(vision_type)(
|
||||||
hf_config, server_args, _processor
|
hf_config, server_args, _processor, *args, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||||
models = [MiniCPMV, MiniCPMO]
|
models = [MiniCPMV, MiniCPMO]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
# Collect special token ids
|
# Collect special token ids
|
||||||
tokenizer = self._processor.tokenizer
|
tokenizer = self._processor.tokenizer
|
||||||
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
|
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
|
||||||
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
self.im_start_id = getattr(tokenizer, "im_start_id", None)
|
self.im_start_id = getattr(tokenizer, "im_start_id", None)
|
||||||
self.im_end_id = getattr(tokenizer, "im_end_id", None)
|
self.im_end_id = getattr(tokenizer, "im_end_id", None)
|
||||||
self.im_token_id = getattr(tokenizer, "unk_id", None)
|
self.im_token_id = getattr(tokenizer, "unk_id", None)
|
||||||
|
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token="(<image>./</image>)",
|
image_token="(<image>./</image>)",
|
||||||
audio_token="(<audio>./</audio>)",
|
audio_token="(<audio>./</audio>)",
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class MllamaImageProcessor(BaseMultimodalProcessor):
|
class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||||
models = [MllamaForConditionalGeneration]
|
models = [MllamaForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token=self._processor.image_token,
|
image_token=self._processor.image_token,
|
||||||
image_token_id=self._processor.image_token_id,
|
image_token_id=self._processor.image_token_id,
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||||
models = [Llama4ForConditionalGeneration]
|
models = [Llama4ForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.vision_config = hf_config.vision_config
|
self.vision_config = hf_config.vision_config
|
||||||
self.text_config = hf_config.text_config
|
self.text_config = hf_config.text_config
|
||||||
self.boi_token_index = hf_config.boi_token_index
|
self.boi_token_index = hf_config.boi_token_index
|
||||||
|
|||||||
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
|
|||||||
class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||||
models = [Phi4MMForCausalLM]
|
models = [Phi4MMForCausalLM]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
self.processor = Phi4MMProcessorAdapter(_processor)
|
self.processor = Phi4MMProcessorAdapter(_processor)
|
||||||
super().__init__(hf_config, server_args, self.processor)
|
super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
|
||||||
|
|
||||||
# the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
|
# the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
|
||||||
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
|
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
|
||||||
|
|||||||
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
|||||||
|
|
||||||
return ncols, nrows
|
return ncols, nrows
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.IM_TOKEN_ID = getattr(
|
self.IM_TOKEN_ID = getattr(
|
||||||
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
|
|||||||
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
||||||
models = [Qwen2AudioForConditionalGeneration]
|
models = [Qwen2AudioForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||||
self.AUDIO_TOKEN_REGEX = re.compile(
|
self.AUDIO_TOKEN_REGEX = re.compile(
|
||||||
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
||||||
|
|||||||
@@ -201,8 +201,8 @@ async def preprocess_video(
|
|||||||
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||||
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
|
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
# The regex that matches expanded image tokens.
|
# The regex that matches expanded image tokens.
|
||||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||||
|
|||||||
@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
_processor: VILAProcessor,
|
_processor: VILAProcessor,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token=self._processor.tokenizer.image_token,
|
image_token=self._processor.tokenizer.image_token,
|
||||||
image_token_id=hf_config.image_token_id,
|
image_token_id=hf_config.image_token_id,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import traceback
|
|||||||
import urllib.request
|
import urllib.request
|
||||||
import weakref
|
import weakref
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import wraps
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
||||||
@@ -28,6 +29,24 @@ from tqdm import tqdm
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def execute_once(func):
|
||||||
|
has_run = None
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
nonlocal has_run
|
||||||
|
if not has_run:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
has_run = True
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@execute_once
|
||||||
|
def info_once(message: str):
|
||||||
|
logger.info(message)
|
||||||
|
|
||||||
|
|
||||||
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
|
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
|
||||||
"""Convert a JSON schema to a string.
|
"""Convert a JSON schema to a string.
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class VLMInputTestBase:
|
|||||||
model_path = None
|
model_path = None
|
||||||
chat_template = None
|
chat_template = None
|
||||||
processor = None
|
processor = None
|
||||||
visual = None # Should be a callable for precomputed features
|
visual = None # Should be a callable for precomputed embeddings
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -41,7 +41,7 @@ class VLMInputTestBase:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _init_visual(cls):
|
def _init_visual(cls):
|
||||||
"""Override in subclass to set up cls.visual as a callable for precomputed features."""
|
"""Override in subclass to set up cls.visual as a callable for precomputed embeddings."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user