feat: update multimodal data handling in engine entrypoint (#8002)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
ImageDataItem,
|
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
|
MultimodalDataInputFormat,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
RpcReqInput,
|
RpcReqInput,
|
||||||
@@ -148,13 +148,9 @@ class Engine(EngineBase):
|
|||||||
# - List of images (one per request in a batch)
|
# - List of images (one per request in a batch)
|
||||||
# - List of lists of images (multiple images per request)
|
# - List of lists of images (multiple images per request)
|
||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
Union[
|
audio_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
List[List[ImageDataItem]],
|
video_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
List[ImageDataItem],
|
|
||||||
ImageDataItem,
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
@@ -187,6 +183,8 @@ class Engine(EngineBase):
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
|
audio_data=audio_data,
|
||||||
|
video_data=video_data,
|
||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
logprob_start_len=logprob_start_len,
|
logprob_start_len=logprob_start_len,
|
||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
@@ -231,13 +229,9 @@ class Engine(EngineBase):
|
|||||||
# - List of images (one per request in a batch)
|
# - List of images (one per request in a batch)
|
||||||
# - List of lists of images (multiple images per request)
|
# - List of lists of images (multiple images per request)
|
||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
Union[
|
audio_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
List[List[ImageDataItem]],
|
video_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
List[ImageDataItem],
|
|
||||||
ImageDataItem,
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
@@ -272,6 +266,8 @@ class Engine(EngineBase):
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
|
audio_data=audio_data,
|
||||||
|
video_data=video_data,
|
||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
logprob_start_len=logprob_start_len,
|
logprob_start_len=logprob_start_len,
|
||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
@@ -295,19 +291,20 @@ class Engine(EngineBase):
|
|||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||||
image_data: Optional[
|
image_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
Union[
|
audio_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
List[List[Union[Image, str]]],
|
video_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
List[Union[Image, str]],
|
|
||||||
Union[Image, str],
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
||||||
Please refer to `EmbeddingReqInput` for the documentation.
|
Please refer to `EmbeddingReqInput` for the documentation.
|
||||||
"""
|
"""
|
||||||
obj = EmbeddingReqInput(text=prompt, image_data=image_data)
|
obj = EmbeddingReqInput(
|
||||||
|
text=prompt,
|
||||||
|
image_data=image_data,
|
||||||
|
audio_data=audio_data,
|
||||||
|
video_data=video_data,
|
||||||
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
ret = loop.run_until_complete(generator.__anext__())
|
ret = loop.run_until_complete(generator.__anext__())
|
||||||
@@ -316,7 +313,9 @@ class Engine(EngineBase):
|
|||||||
async def async_encode(
|
async def async_encode(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||||
image_data: Optional[Union[List[str], str]] = None,
|
image_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
|
audio_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
|
video_data: Optional[MultimodalDataInputFormat] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Asynchronous version of encode method.
|
Asynchronous version of encode method.
|
||||||
@@ -324,7 +323,12 @@ class Engine(EngineBase):
|
|||||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
||||||
Please refer to `EmbeddingReqInput` for the documentation.
|
Please refer to `EmbeddingReqInput` for the documentation.
|
||||||
"""
|
"""
|
||||||
obj = EmbeddingReqInput(text=prompt, image_data=image_data)
|
obj = EmbeddingReqInput(
|
||||||
|
text=prompt,
|
||||||
|
image_data=image_data,
|
||||||
|
audio_data=audio_data,
|
||||||
|
video_data=video_data,
|
||||||
|
)
|
||||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
return await generator.__anext__()
|
return await generator.__anext__()
|
||||||
|
|
||||||
|
|||||||
@@ -42,8 +42,21 @@ class SessionParams:
|
|||||||
drop_previous_output: Optional[bool] = None
|
drop_previous_output: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
AudioDataItem = Union[str, Dict]
|
# Type definitions for multimodal input data
|
||||||
ImageDataItem = Union[Image, str, Dict]
|
# Individual data item types for each modality
|
||||||
|
ImageDataInputItem = Union[Image, str, Dict]
|
||||||
|
AudioDataInputItem = Union[str, Dict]
|
||||||
|
VideoDataInputItem = Union[str, Dict]
|
||||||
|
# Union type for any multimodal data item
|
||||||
|
MultimodalDataInputItem = Union[
|
||||||
|
ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
|
||||||
|
]
|
||||||
|
# Format types supporting single items, lists, or nested lists for batch processing
|
||||||
|
MultimodalDataInputFormat = Union[
|
||||||
|
List[List[MultimodalDataInputItem]],
|
||||||
|
List[MultimodalDataInputItem],
|
||||||
|
MultimodalDataInputItem,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -60,13 +73,11 @@ class GenerateReqInput:
|
|||||||
# - List of images (one per request in a batch)
|
# - List of images (one per request in a batch)
|
||||||
# - List of lists of images (multiple images per request)
|
# - List of lists of images (multiple images per request)
|
||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[MultimodalDataInputFormat] = None
|
||||||
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
|
|
||||||
] = None
|
|
||||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
|
||||||
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
|
|
||||||
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
|
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||||
video_data: Optional[Union[List[List[str]], List[str], str]] = None
|
video_data: Optional[MultimodalDataInputFormat] = None
|
||||||
|
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||||
|
audio_data: Optional[MultimodalDataInputFormat] = None
|
||||||
# The sampling_params. See descriptions below.
|
# The sampling_params. See descriptions below.
|
||||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
@@ -524,13 +535,11 @@ class EmbeddingReqInput:
|
|||||||
# - List of images (one per request in a batch)
|
# - List of images (one per request in a batch)
|
||||||
# - List of lists of images (multiple images per request)
|
# - List of lists of images (multiple images per request)
|
||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[MultimodalDataInputFormat] = None
|
||||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
|
||||||
] = None
|
|
||||||
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
|
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||||
video_data: Optional[Union[List[str], str]] = None
|
video_data: Optional[MultimodalDataInputFormat] = None
|
||||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||||
audio_data: Optional[Union[List[str], str]] = None
|
audio_data: Optional[MultimodalDataInputFormat] = None
|
||||||
# The token ids for text; one can either specify text or input_ids.
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
@@ -610,8 +619,6 @@ class EmbeddingReqInput:
|
|||||||
if self.is_cross_encoder_request:
|
if self.is_cross_encoder_request:
|
||||||
return EmbeddingReqInput(
|
return EmbeddingReqInput(
|
||||||
text=[self.text[i]] if self.text is not None else None,
|
text=[self.text[i]] if self.text is not None else None,
|
||||||
input_ids=None,
|
|
||||||
image_data=None,
|
|
||||||
sampling_params=self.sampling_params[i],
|
sampling_params=self.sampling_params[i],
|
||||||
rid=self.rid[i],
|
rid=self.rid[i],
|
||||||
is_cross_encoder_request=True,
|
is_cross_encoder_request=True,
|
||||||
@@ -621,6 +628,8 @@ class EmbeddingReqInput:
|
|||||||
text=self.text[i] if self.text is not None else None,
|
text=self.text[i] if self.text is not None else None,
|
||||||
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
||||||
image_data=self.image_data[i] if self.image_data is not None else None,
|
image_data=self.image_data[i] if self.image_data is not None else None,
|
||||||
|
audio_data=self.audio_data[i] if self.audio_data is not None else None,
|
||||||
|
video_data=self.video_data[i] if self.video_data is not None else None,
|
||||||
sampling_params=self.sampling_params[i],
|
sampling_params=self.sampling_params[i],
|
||||||
rid=self.rid[i],
|
rid=self.rid[i],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
ImageDataItem,
|
ImageDataInputItem,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.models.vila import VILAForConditionalGeneration
|
from sglang.srt.models.vila import VILAForConditionalGeneration
|
||||||
@@ -42,7 +42,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: Optional[ImageDataItem | List[ImageDataItem]],
|
image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]],
|
||||||
input_text: str | List[int],
|
input_text: str | List[int],
|
||||||
request_obj: GenerateReqInput | EmbeddingReqInput,
|
request_obj: GenerateReqInput | EmbeddingReqInput,
|
||||||
max_req_input_len: int,
|
max_req_input_len: int,
|
||||||
|
|||||||
Reference in New Issue
Block a user