diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index e93efdeb9..e130dc227 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, GetWeightsByNameReqInput, - ImageDataItem, InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, + MultimodalDataInputFormat, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, RpcReqInput, @@ -148,13 +148,9 @@ class Engine(EngineBase): # - List of images (one per request in a batch) # - List of lists of images (multiple images per request) # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[ - List[List[ImageDataItem]], - List[ImageDataItem], - ImageDataItem, - ] - ] = None, + image_data: Optional[MultimodalDataInputFormat] = None, + audio_data: Optional[MultimodalDataInputFormat] = None, + video_data: Optional[MultimodalDataInputFormat] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -187,6 +183,8 @@ class Engine(EngineBase): input_ids=input_ids, sampling_params=sampling_params, image_data=image_data, + audio_data=audio_data, + video_data=video_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, @@ -231,13 +229,9 @@ class Engine(EngineBase): # - List of images (one per request in a batch) # - List of lists of images (multiple images per request) # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[ - List[List[ImageDataItem]], - List[ImageDataItem], - ImageDataItem, - ] - ] = None, + image_data: Optional[MultimodalDataInputFormat] = None, + audio_data: Optional[MultimodalDataInputFormat] = None, + video_data: Optional[MultimodalDataInputFormat] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -272,6 +266,8 @@ class Engine(EngineBase): input_ids=input_ids, sampling_params=sampling_params, image_data=image_data, + audio_data=audio_data, + video_data=video_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, @@ -295,19 +291,20 @@ class Engine(EngineBase): def encode( self, prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, + image_data: Optional[MultimodalDataInputFormat] = None, + audio_data: Optional[MultimodalDataInputFormat] = None, + video_data: Optional[MultimodalDataInputFormat] = None, ) -> Dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. Please refer to `EmbeddingReqInput` for the documentation. """ - obj = EmbeddingReqInput(text=prompt, image_data=image_data) + obj = EmbeddingReqInput( + text=prompt, + image_data=image_data, + audio_data=audio_data, + video_data=video_data, + ) loop = asyncio.get_event_loop() generator = self.tokenizer_manager.generate_request(obj, None) ret = loop.run_until_complete(generator.__anext__()) @@ -316,7 +313,9 @@ class Engine(EngineBase): async def async_encode( self, prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - image_data: Optional[Union[List[str], str]] = None, + image_data: Optional[MultimodalDataInputFormat] = None, + audio_data: Optional[MultimodalDataInputFormat] = None, + video_data: Optional[MultimodalDataInputFormat] = None, ) -> Dict: """ Asynchronous version of encode method. @@ -324,7 +323,12 @@ class Engine(EngineBase): The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. Please refer to `EmbeddingReqInput` for the documentation. """ - obj = EmbeddingReqInput(text=prompt, image_data=image_data) + obj = EmbeddingReqInput( + text=prompt, + image_data=image_data, + audio_data=audio_data, + video_data=video_data, + ) generator = self.tokenizer_manager.generate_request(obj, None) return await generator.__anext__() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 745b7a2da..6eebf21e9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -42,8 +42,21 @@ class SessionParams: drop_previous_output: Optional[bool] = None -AudioDataItem = Union[str, Dict] -ImageDataItem = Union[Image, str, Dict] +# Type definitions for multimodal input data +# Individual data item types for each modality +ImageDataInputItem = Union[Image, str, Dict] +AudioDataInputItem = Union[str, Dict] +VideoDataInputItem = Union[str, Dict] +# Union type for any multimodal data item +MultimodalDataInputItem = Union[ + ImageDataInputItem, VideoDataInputItem, AudioDataInputItem +] +# Format types supporting single items, lists, or nested lists for batch processing +MultimodalDataInputFormat = Union[ + List[List[MultimodalDataInputItem]], + List[MultimodalDataInputItem], + MultimodalDataInputItem, +] @dataclass @@ -60,13 +73,11 @@ class GenerateReqInput: # - List of images (one per request in a batch) # - List of lists of images (multiple images per request) # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem] - ] = None - # The audio input. Like image data, it can be a file name, a url, or base64 encoded string. - audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None + image_data: Optional[MultimodalDataInputFormat] = None # The video input. Like image data, it can be a file name, a url, or base64 encoded string. - video_data: Optional[Union[List[List[str]], List[str], str]] = None + video_data: Optional[MultimodalDataInputFormat] = None + # The audio input. Like image data, it can be a file name, a url, or base64 encoded string. + audio_data: Optional[MultimodalDataInputFormat] = None # The sampling_params. See descriptions below. sampling_params: Optional[Union[List[Dict], Dict]] = None # The request id. @@ -524,13 +535,11 @@ class EmbeddingReqInput: # - List of images (one per request in a batch) # - List of lists of images (multiple images per request) # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]] - ] = None + image_data: Optional[MultimodalDataInputFormat] = None # The video input. Like image data, it can be a file name, a url, or base64 encoded string. - video_data: Optional[Union[List[str], str]] = None + video_data: Optional[MultimodalDataInputFormat] = None # The audio input. Like image data, it can be a file name, a url, or base64 encoded string. - audio_data: Optional[Union[List[str], str]] = None + audio_data: Optional[MultimodalDataInputFormat] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None # The request id. @@ -610,8 +619,6 @@ class EmbeddingReqInput: if self.is_cross_encoder_request: return EmbeddingReqInput( text=[self.text[i]] if self.text is not None else None, - input_ids=None, - image_data=None, sampling_params=self.sampling_params[i], rid=self.rid[i], is_cross_encoder_request=True, @@ -621,6 +628,8 @@ class EmbeddingReqInput: text=self.text[i] if self.text is not None else None, input_ids=self.input_ids[i] if self.input_ids is not None else None, image_data=self.image_data[i] if self.image_data is not None else None, + audio_data=self.audio_data[i] if self.audio_data is not None else None, + video_data=self.video_data[i] if self.video_data is not None else None, sampling_params=self.sampling_params[i], rid=self.rid[i], ) diff --git a/python/sglang/srt/multimodal/processors/vila.py b/python/sglang/srt/multimodal/processors/vila.py index a1625b9ed..c4d676c6d 100644 --- a/python/sglang/srt/multimodal/processors/vila.py +++ b/python/sglang/srt/multimodal/processors/vila.py @@ -8,7 +8,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, - ImageDataItem, + ImageDataInputItem, ) from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.vila import VILAForConditionalGeneration @@ -42,7 +42,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): async def process_mm_data_async( self, - image_data: Optional[ImageDataItem | List[ImageDataItem]], + image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]], input_text: str | List[int], request_obj: GenerateReqInput | EmbeddingReqInput, max_req_input_len: int,