diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 13ca29c54..78a9762ee 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -3,8 +3,9 @@ Multi-modality utils """ import hashlib +import pickle from abc import abstractmethod -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple import numpy as np import torch @@ -27,6 +28,130 @@ from sglang.utils import logger # propagation that can cause some log messages (like 'server is fired up') to not appear # in the console when multimodal support is enabled. +# TODO(mick): nccl +# cuda_ipc: for intranode tensor sharing +TensorTransportMode = Literal["cuda_ipc", "auto", "default"] + + +class TransportProxyTensor(torch.Tensor): + """ + A convenient torch.Tensor subclass that carries extra metadata and supports + efficient inter-process communications + """ + + @staticmethod + def __new__( + cls, + data: torch.Tensor, + name: Optional[str] = None, + fields: Optional[Dict[str, Any]] = None, + transport_mode: TensorTransportMode = "default", + *args, + **kwargs, + ): + + if not isinstance(data, torch.Tensor): + raise TypeError( + f"Input 'data' must be a torch.Tensor, but got {type(data)}" + ) + + instance = data.as_subclass(cls) + + instance._metadata = { + "name": name, + "fields": fields if fields is not None else {}, + "transport_mode": transport_mode, + } + + return instance + + def __getstate__(self): + """ + Called during pickling. Implements the serialization logic. + """ + # acquire all serialize metadata from _metadata + state = { + "metadata": self._metadata, + "tensor_data": None, + "ipc_extra": None, + } + + transport_mode = self._metadata.get("transport_mode", "default") + + if transport_mode == "cuda_ipc" and self.is_cuda: + try: + storage = self.untyped_storage() + handle = storage._share_cuda_() + + state["ipc_extra"] = { + "handle": handle, + "shape": self.shape, + "dtype": self.dtype, + "stride": self.stride(), + "device_index": self.device.index, + } + state["tensor_data"] = None + except Exception as e: + print_warning_once( + f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport." + ) + state["metadata"]["transport_mode"] = "default" + state["tensor_data"] = self.as_subclass(torch.Tensor) + else: + state["metadata"]["transport_mode"] = "default" + state["tensor_data"] = self.as_subclass(torch.Tensor) + + return state + + def __setstate__(self, state: Dict[str, Any]): + """ + Called during unpickling. Implements the deserialization logic. + """ + self._metadata = state["metadata"] + + transport_mode = self._metadata.get("transport_mode", "default") + + if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None: + ipc_extra = state["ipc_extra"] + handle, shape, dtype, stride, source_device_index = ( + ipc_extra["handle"], + ipc_extra["shape"], + ipc_extra["dtype"], + ipc_extra["stride"], + ipc_extra["device_index"], + ) + + try: + target_device = torch.device(f"cuda:{source_device_index}") + with torch.cuda.device(target_device): + storage = torch.UntypedStorage._new_shared_cuda(*handle) + reconstructed_tensor = torch.empty( + 0, dtype=dtype, device=target_device + ).set_(storage, storage_offset=0, size=shape, stride=stride) + self.set_(reconstructed_tensor) + except Exception as e: + print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).") + raise e + + elif state["tensor_data"] is not None: + self.set_(state["tensor_data"]) + else: + raise pickle.UnpicklingError( + "Invalid state for TransportProxyTensor: no tensor data found." + ) + + @property + def name(self) -> Optional[str]: + return self._metadata.get("name") + + @property + def fields(self) -> Dict[str, Any]: + return self._metadata.get("fields", {}) + + @property + def transport_mode(self) -> TensorTransportMode: + return self._metadata.get("transport_mode", "default") + class MultiModalityDataPaddingPattern: """ diff --git a/python/sglang/srt/managers/multimodal_processor.py b/python/sglang/srt/managers/multimodal_processor.py index 76679358a..51b6f3d92 100644 --- a/python/sglang/srt/managers/multimodal_processor.py +++ b/python/sglang/srt/managers/multimodal_processor.py @@ -12,18 +12,6 @@ logger = logging.getLogger(__name__) PROCESSOR_MAPPING = {} -class DummyMultimodalProcessor(BaseMultimodalProcessor): - def __init__(self): - pass - - async def process_mm_data_async(self, *args, **kwargs): - return None - - -def get_dummy_processor(): - return DummyMultimodalProcessor() - - def import_processors(): package_name = "sglang.srt.multimodal.processors" package = importlib.import_module(package_name) @@ -49,11 +37,12 @@ def import_processors(): def get_mm_processor( - hf_config, server_args: ServerArgs, processor + hf_config, server_args: ServerArgs, processor, transport_mode ) -> BaseMultimodalProcessor: for model_cls, processor_cls in PROCESSOR_MAPPING.items(): if model_cls.__name__ in hf_config.architectures: - return processor_cls(hf_config, server_args, processor) + return processor_cls(hf_config, server_args, processor, transport_mode) + raise ValueError( f"No processor registered for architecture: {hf_config.architectures}.\n" f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ad8bcf119..283da3394 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -209,10 +209,11 @@ class MultimodalDataItem: hash: int = None pad_value: int = None offsets: Optional[list] = None + # the raw features returned by processor, e.g. pixel_values or audio_features feature: Union[torch.Tensor, np.ndarray] = None - - # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio + # the precomputed embeddings, passed as final encoder embeddings + # One and only one of the feature and precomputed_embeddings will be empty precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None # Model-specific data stored in a dictionary diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 0f65fa925..77c805aac 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, ) +from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams @@ -166,6 +167,16 @@ class ReqState: output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) +def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: + is_cross_node = server_args.dist_init_addr + + if is_cross_node: + # Fallback to default CPU transport for multi-node + return "default" + else: + return "cuda_ipc" + + class TokenizerManager: """TokenizerManager is a process that tokenizes the text.""" @@ -216,12 +227,13 @@ class TokenizerManager: revision=server_args.revision, use_fast=not server_args.disable_fast_image_processor, ) + transport_mode = _determine_tensor_transport_mode(self.server_args) # We want to parallelize the image pre-processing so we create an executor for it # We create mm_processor for any skip_tokenizer_init to make sure we still encode # images even with skip_tokenizer_init=False. self.mm_processor = get_mm_processor( - self.model_config.hf_config, server_args, _processor + self.model_config.hf_config, server_args, _processor, transport_mode ) if server_args.skip_tokenizer_init: diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 3d548a19e..3f62a14d1 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -12,6 +12,7 @@ import torch from PIL import Image from transformers import BaseImageProcessorFast +from sglang.srt.managers.mm_utils import TransportProxyTensor from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.utils import load_audio, load_image, load_video, logger @@ -142,11 +143,14 @@ class MultimodalSpecialTokens: class BaseMultimodalProcessor(ABC): models = [] - def __init__(self, hf_config, server_args, _processor): + def __init__( + self, hf_config, server_args, _processor, transport_mode, *args, **kwargs + ): self.hf_config = hf_config self._processor = _processor self.arch = hf_config.architectures[0] self.server_args = server_args + self.transport_mode = transport_mode # FIXME: not accurate, model and image specific self.NUM_TOKEN_PER_FRAME = 330 @@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC): return_tensors="pt", **kwargs, ) - if "pixel_values" in result and isinstance( - result["pixel_values"], torch.Tensor - ): - result["pixel_values"] = result["pixel_values"].to("cpu") return result @abstractmethod @@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC): ) -> List[MultimodalDataItem]: """Create mm_items directly from processor output.""" items: dict[Modality, MultimodalDataItem] = {} - for attr_name, value in data_dict.items(): if attr_name == "input_ids": continue @@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC): mm_token_id=mm_token_id, ) + # post-process + for item in all_collected_items: + # replace the feature tensor with a proxy + if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda: + item.feature = TransportProxyTensor( + transport_mode=self.transport_mode, data=item.feature + ) + elif ( + isinstance(item.precomputed_embeddings, torch.Tensor) + and item.precomputed_embeddings.is_cuda + ): + item.precomputed_embeddings = TransportProxyTensor( + transport_mode=self.transport_mode, data=item.precomputed_embeddings + ) + return all_collected_items, input_ids, ret diff --git a/python/sglang/srt/multimodal/processors/clip.py b/python/sglang/srt/multimodal/processors/clip.py index 0925212cb..19ff71e78 100644 --- a/python/sglang/srt/multimodal/processors/clip.py +++ b/python/sglang/srt/multimodal/processors/clip.py @@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class ClipImageProcessor(BaseMultimodalProcessor): models = [CLIPModel] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.mm_tokens = MultimodalSpecialTokens(image_token="").build( _processor ) diff --git a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py index 9847929f7..b09402d0b 100644 --- a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py +++ b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py @@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): models = [DeepseekVL2ForCausalLM] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.mm_tokens = MultimodalSpecialTokens( image_token="", image_token_id=self._processor.image_token_id ).build(_processor) diff --git a/python/sglang/srt/multimodal/processors/gemma3.py b/python/sglang/srt/multimodal/processors/gemma3.py index 9abf172b2..cbfb45e84 100644 --- a/python/sglang/srt/multimodal/processors/gemma3.py +++ b/python/sglang/srt/multimodal/processors/gemma3.py @@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok class Gemma3SGLangImageProcessor(SGLangBaseProcessor): models = [Gemma3ForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index self.mm_tokens = MultimodalSpecialTokens( diff --git a/python/sglang/srt/multimodal/processors/gemma3n.py b/python/sglang/srt/multimodal/processors/gemma3n.py index 938819d91..4bfbcaffa 100644 --- a/python/sglang/srt/multimodal/processors/gemma3n.py +++ b/python/sglang/srt/multimodal/processors/gemma3n.py @@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor): models = [Gemma3nForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_START_TOKEN_ID = hf_config.boi_token_id self.IM_END_TOKEN_ID = hf_config.eoi_token_id diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index 12823077f..234d57d35 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class InternVLImageProcessor(BaseMultimodalProcessor): models = [InternVLChatModel] - def __init__(self, hf_config, server_args, _image_processor): - super().__init__(hf_config, server_args, _image_processor) + def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs): + super().__init__(hf_config, server_args, _image_processor, *args, **kwargs) image_size = hf_config.force_image_size or hf_config.vision_config.image_size patch_size = hf_config.vision_config.patch_size diff --git a/python/sglang/srt/multimodal/processors/janus_pro.py b/python/sglang/srt/multimodal/processors/janus_pro.py index 4dd8c1a84..54d6c1978 100644 --- a/python/sglang/srt/multimodal/processors/janus_pro.py +++ b/python/sglang/srt/multimodal/processors/janus_pro.py @@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class JanusProImageProcessor(BaseMultimodalProcessor): models = [MultiModalityCausalLM] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.mm_tokens = MultimodalSpecialTokens( image_token=_processor.image_token, diff --git a/python/sglang/srt/multimodal/processors/kimi_vl.py b/python/sglang/srt/multimodal/processors/kimi_vl.py index 84c4a5133..541ed5c9e 100644 --- a/python/sglang/srt/multimodal/processors/kimi_vl.py +++ b/python/sglang/srt/multimodal/processors/kimi_vl.py @@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok class KimiVLImageProcessor(SGLangBaseProcessor): models = [KimiVLForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.mm_tokens = MultimodalSpecialTokens( image_token="<|media_pad|>", # TODO: could we convert in MultimodalSpecialTokens? diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py index f4504ecea..5031dccbd 100644 --- a/python/sglang/srt/multimodal/processors/llava.py +++ b/python/sglang/srt/multimodal/processors/llava.py @@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor): LlavaMistralForCausalLM, ] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) @staticmethod def _process_single_image_task( @@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor): f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`" ) - def __init__(self, hf_config, server_args, _processor): + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): assert hasattr(hf_config, "vision_config") assert hasattr(hf_config, "text_config") self.vision_config = hf_config.vision_config @@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor): if vision_type := getattr(self.vision_config, "model_type"): self.inner = self._get_sgl_processor_cls(vision_type)( - hf_config, server_args, _processor + hf_config, server_args, _processor, *args, **kwargs ) else: raise ValueError( diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py index ed4f86511..9ddbf4fb6 100644 --- a/python/sglang/srt/multimodal/processors/minicpm.py +++ b/python/sglang/srt/multimodal/processors/minicpm.py @@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): models = [MiniCPMV, MiniCPMO] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) # Collect special token ids tokenizer = self._processor.tokenizer self.slice_start_id = getattr(tokenizer, "slice_start_id", None) @@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): self.im_start_id = getattr(tokenizer, "im_start_id", None) self.im_end_id = getattr(tokenizer, "im_end_id", None) self.im_token_id = getattr(tokenizer, "unk_id", None) - self.mm_tokens = MultimodalSpecialTokens( image_token="(./)", audio_token="()", diff --git a/python/sglang/srt/multimodal/processors/mlama.py b/python/sglang/srt/multimodal/processors/mlama.py index dd3184452..432215a4f 100644 --- a/python/sglang/srt/multimodal/processors/mlama.py +++ b/python/sglang/srt/multimodal/processors/mlama.py @@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class MllamaImageProcessor(BaseMultimodalProcessor): models = [MllamaForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.mm_tokens = MultimodalSpecialTokens( image_token=self._processor.image_token, image_token_id=self._processor.image_token_id, diff --git a/python/sglang/srt/multimodal/processors/mllama4.py b/python/sglang/srt/multimodal/processors/mllama4.py index 2d0eba2fd..fd22d3848 100644 --- a/python/sglang/srt/multimodal/processors/mllama4.py +++ b/python/sglang/srt/multimodal/processors/mllama4.py @@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class Mllama4ImageProcessor(BaseMultimodalProcessor): models = [Llama4ForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.vision_config = hf_config.vision_config self.text_config = hf_config.text_config self.boi_token_index = hf_config.boi_token_index diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index 720e3c132..1487d2ca2 100644 --- a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin): class Phi4MMMultimodalProcessor(BaseMultimodalProcessor): models = [Phi4MMForCausalLM] - def __init__(self, hf_config, server_args, _processor): + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.processor = Phi4MMProcessorAdapter(_processor) - super().__init__(hf_config, server_args, self.processor) + super().__init__(hf_config, server_args, self.processor, *args, **kwargs) # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index fdfd6bd62..af5cedec9 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor): return ncols, nrows - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_TOKEN_ID = getattr( hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID ) diff --git a/python/sglang/srt/multimodal/processors/qwen_audio.py b/python/sglang/srt/multimodal/processors/qwen_audio.py index 34d440375..b2bb38464 100644 --- a/python/sglang/srt/multimodal/processors/qwen_audio.py +++ b/python/sglang/srt/multimodal/processors/qwen_audio.py @@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import ( class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor): models = [Qwen2AudioForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>" self.AUDIO_TOKEN_REGEX = re.compile( r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>" diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 1b1de4369..f67f72b95 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -201,8 +201,8 @@ async def preprocess_video( class Qwen2_5VLImageProcessor(SGLangBaseProcessor): models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) # The regex that matches expanded image tokens. self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id diff --git a/python/sglang/srt/multimodal/processors/vila.py b/python/sglang/srt/multimodal/processors/vila.py index 7070dfe73..5f9586b6c 100644 --- a/python/sglang/srt/multimodal/processors/vila.py +++ b/python/sglang/srt/multimodal/processors/vila.py @@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): hf_config: PretrainedConfig, server_args: ServerArgs, _processor: VILAProcessor, + *args, + **kwargs, ) -> None: - super().__init__(hf_config, server_args, _processor) + super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.mm_tokens = MultimodalSpecialTokens( image_token=self._processor.tokenizer.image_token, image_token_id=hf_config.image_token_id, diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 83c653232..b7600b1a6 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -14,6 +14,7 @@ import traceback import urllib.request import weakref from concurrent.futures import ThreadPoolExecutor +from functools import wraps from io import BytesIO from json import dumps from typing import Any, Callable, List, Optional, Tuple, Type, Union @@ -28,6 +29,24 @@ from tqdm import tqdm logger = logging.getLogger(__name__) +def execute_once(func): + has_run = None + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal has_run + if not has_run: + func(*args, **kwargs) + has_run = True + + return wrapper + + +@execute_once +def info_once(message: str): + logger.info(message) + + def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: """Convert a JSON schema to a string. Parameters diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py index 79625ee82..b2cf0073d 100644 --- a/test/srt/test_vlm_input_format.py +++ b/test/srt/test_vlm_input_format.py @@ -24,7 +24,7 @@ class VLMInputTestBase: model_path = None chat_template = None processor = None - visual = None # Should be a callable for precomputed features + visual = None # Should be a callable for precomputed embeddings @classmethod def setUpClass(cls): @@ -41,7 +41,7 @@ class VLMInputTestBase: @classmethod def _init_visual(cls): - """Override in subclass to set up cls.visual as a callable for precomputed features.""" + """Override in subclass to set up cls.visual as a callable for precomputed embeddings.""" raise NotImplementedError def setUp(self):