[feat&refactor] Enhance multimodal input support with refactor io_struct (#4938)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
XinyuanTong
2025-04-08 14:48:07 -07:00
committed by GitHub
parent f8194b267c
commit d09a51f1f6
4 changed files with 811 additions and 104 deletions

View File

@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq
import zmq.asyncio
from PIL.Image import Image
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -135,9 +136,19 @@ class Engine:
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -190,9 +201,19 @@ class Engine:
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -228,7 +249,13 @@ class Engine:
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
image_data: Optional[Union[List[str], str]] = None,
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
) -> Dict:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.

View File

@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
@@ -56,9 +57,19 @@ class VerlEngine:
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,

View File

@@ -20,7 +20,13 @@ import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
# handle serialization of Image for pydantic
if TYPE_CHECKING:
from PIL.Image import Image
else:
Image = Any
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -42,10 +48,16 @@ class GenerateReqInput:
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
@@ -84,6 +96,31 @@ class GenerateReqInput:
return_hidden_states: bool = False
def normalize_batch_and_arguments(self):
    """Bring the request into its canonical single-example or batch form.

    The steps run in a fixed order: input validation, batch-size detection,
    parallel-sampling handling, then either the single-example or the batch
    normalizer, and finally the session-parameter check.

    Raises:
        ValueError: If inputs are not properly specified (e.g., none or all of
            text, input_ids, input_embeds are provided)
    """
    preparation_steps = (
        self._validate_inputs,
        self._determine_batch_size,
        self._handle_parallel_sampling,
    )
    for step in preparation_steps:
        step()

    # `is_single` is read only now: parallel-sampling handling may have
    # turned a single example into a batch.
    if self.is_single:
        normalize = self._normalize_single_inputs
    else:
        normalize = self._normalize_batch_inputs
    normalize()

    self._validate_session_params()
def _validate_inputs(self):
"""Validate that the input configuration is valid."""
if (
self.text is None and self.input_ids is None and self.input_embeds is None
) or (
@@ -95,7 +132,8 @@ class GenerateReqInput:
"Either text, input_ids or input_embeds should be provided."
)
# Derive the batch size
def _determine_batch_size(self):
"""Determine if this is a single example or a batch and the batch size."""
if self.text is not None:
if isinstance(self.text, str):
self.is_single = True
@@ -119,21 +157,25 @@ class GenerateReqInput:
self.is_single = True
self.batch_size = 1
else:
self.is_single = False
self.batch_size = len(self.input_embeds)
# Handle parallel sampling
# When parallel sampling is used, we always treat the input as a batch.
def _handle_parallel_sampling(self):
"""Handle parallel sampling parameters and adjust batch size if needed."""
# Determine parallel sample count
if self.sampling_params is None:
self.parallel_sample_num = 1
elif isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list):
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
assert all(
self.parallel_sample_num == sampling_params.get("n", 1)
for sampling_params in self.sampling_params
), "The parallel_sample_num should be the same for all samples in sample params."
for sampling_params in self.sampling_params:
if self.parallel_sample_num != sampling_params.get("n", 1):
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
# If using parallel sampling with a single example, convert to batch
if self.parallel_sample_num > 1 and self.is_single:
self.is_single = False
if self.text is not None:
@@ -141,97 +183,190 @@ class GenerateReqInput:
if self.input_ids is not None:
self.input_ids = [self.input_ids]
# Fill in default arguments
if self.is_single:
if self.sampling_params is None:
self.sampling_params = {}
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.return_logprob is None:
self.return_logprob = False
if self.logprob_start_len is None:
self.logprob_start_len = -1
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = None
def _normalize_single_inputs(self):
    """Fill in defaults for the optional fields of a single (non-batch) request."""
    if self.sampling_params is None:
        self.sampling_params = {}
    if self.rid is None:
        # A fresh random hex id is generated only when the caller gave none.
        self.rid = uuid.uuid4().hex

    # Logprob-related scalars: an unset (None) value falls back to its default.
    scalar_defaults = (
        ("return_logprob", False),
        ("logprob_start_len", -1),
        ("top_logprobs_num", 0),
    )
    for attr_name, fallback in scalar_defaults:
        if getattr(self, attr_name) is None:
            setattr(self, attr_name, fallback)

    # Any falsy value (None as well as an empty list) collapses to None.
    self.token_ids_logprob = self.token_ids_logprob or None
def _normalize_batch_inputs(self):
"""Normalize inputs for a batch of examples, including parallel sampling expansion."""
# Calculate expanded batch size
if self.parallel_sample_num == 1:
num = self.batch_size
else:
if self.parallel_sample_num == 1:
num = self.batch_size
# Expand parallel_sample_num
num = self.batch_size * self.parallel_sample_num
# Expand input based on type
self._expand_inputs(num)
self._normalize_lora_paths(num)
self._normalize_image_data(num)
self._normalize_audio_data(num)
self._normalize_sampling_params(num)
self._normalize_rid(num)
self._normalize_logprob_params(num)
self._normalize_custom_logit_processor(num)
def _expand_inputs(self, num):
    """Repeat the primary input of a batch once per parallel sample.

    Exactly one of ``text``, ``input_ids`` or ``input_embeds`` is handled,
    checked in that order; the first one that is not None is validated and
    replaced by itself repeated ``parallel_sample_num`` times. ``num`` is
    accepted for signature parity with the other normalizers and is not used.

    Raises:
        ValueError: If the selected input does not have the batch shape
            (a list; for ``input_ids`` a list whose first element is a list).
    """
    if self.text is not None:
        if not isinstance(self.text, list):
            raise ValueError("Text should be a list for batch processing.")
        # `x * n` builds a new list; the caller's list is not extended in place.
        self.text = self.text * self.parallel_sample_num
        return

    if self.input_ids is not None:
        is_nested_list = isinstance(self.input_ids, list) and isinstance(
            self.input_ids[0], list
        )
        if not is_nested_list:
            raise ValueError(
                "input_ids should be a list of lists for batch processing."
            )
        self.input_ids = self.input_ids * self.parallel_sample_num
        return

    if self.input_embeds is not None:
        if not isinstance(self.input_embeds, list):
            raise ValueError("input_embeds should be a list for batch processing.")
        self.input_embeds = self.input_embeds * self.parallel_sample_num
def _normalize_lora_paths(self, num):
    """Turn ``lora_path`` into a per-request list for a batch.

    None is left untouched. A single string is broadcast to ``num`` entries.
    A list is repeated ``parallel_sample_num`` times.

    Raises:
        ValueError: If ``lora_path`` is neither None, a string nor a list.
    """
    current = self.lora_path
    if current is None:
        return

    if isinstance(current, str):
        self.lora_path = [current] * num
        return

    if isinstance(current, list):
        self.lora_path = current * self.parallel_sample_num
        return

    raise ValueError("lora_path should be a list or a string.")
def _normalize_image_data(self, num):
"""Normalize image data for batch processing."""
if self.image_data is None:
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
# Single image, convert to list of single-image lists
self.image_data = [[self.image_data]] * num
self.modalities = ["image"] * num
elif isinstance(self.image_data, list):
if len(self.image_data) != self.batch_size:
raise ValueError(
"The length of image_data should be equal to the batch size."
)
self.modalities = []
if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
# Already a list of lists, keep as is
for i in range(len(self.image_data)):
if self.image_data[i] is None or self.image_data[i] == [None]:
self.modalities.append(None)
elif len(self.image_data[i]) == 1:
self.modalities.append("image")
elif len(self.image_data[i]) > 1:
self.modalities.append("multi-images")
# Expand parallel_sample_num
num = self.batch_size * self.parallel_sample_num
if not self.image_data:
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num
elif isinstance(self.image_data, list):
pass
if self.audio_data is None:
self.audio_data = [None] * num
elif not isinstance(self.audio_data, list):
self.audio_data = [self.audio_data] * num
elif isinstance(self.audio_data, list):
pass
if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
self.sampling_params = [self.sampling_params] * num
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)]
self.image_data = self.image_data * self.parallel_sample_num
self.modalities = self.modalities * self.parallel_sample_num
else:
assert isinstance(self.rid, list), "The rid should be a list."
# List of images for a batch, wrap each in a list
wrapped_images = [[img] for img in self.image_data]
# Expand for parallel sampling
self.image_data = wrapped_images * self.parallel_sample_num
self.modalities = ["image"] * num
if self.return_logprob is None:
self.return_logprob = [False] * num
elif not isinstance(self.return_logprob, list):
self.return_logprob = [self.return_logprob] * num
def _normalize_audio_data(self, num):
    """Turn ``audio_data`` into a per-request list for a batch.

    A list is repeated ``parallel_sample_num`` times. Anything else,
    including None, is broadcast to ``num`` entries.
    """
    audio = self.audio_data
    if isinstance(audio, list):
        self.audio_data = audio * self.parallel_sample_num
    else:
        # Covers both the "no audio" case (None) and a single audio item.
        self.audio_data = [audio] * num
def _normalize_sampling_params(self, num):
    """Turn ``sampling_params`` into a per-request list for a batch.

    None becomes ``num`` independent empty dicts. A single dict is broadcast
    to ``num`` entries. Any other value is treated as a list and repeated
    ``parallel_sample_num`` times.

    Args:
        num: Expanded batch size, i.e. the number of entries to produce when
            broadcasting.
    """
    if self.sampling_params is None:
        # Build one dict per request. `[{}] * num` would put the *same* dict
        # object in every slot, so mutating one request's params would
        # silently change all of them.
        self.sampling_params = [{} for _ in range(num)]
    elif isinstance(self.sampling_params, dict):
        # NOTE: the caller-supplied dict is shared by reference across all
        # entries; it is not copied.
        self.sampling_params = [self.sampling_params] * num
    else:  # Already a list
        self.sampling_params = self.sampling_params * self.parallel_sample_num
def _normalize_rid(self, num):
    """Ensure ``rid`` is a list of request ids for a batch.

    When no id was supplied, ``num`` fresh random hex ids are generated.
    A supplied list is kept unchanged.

    Raises:
        ValueError: If ``rid`` is set but is not a list.
    """
    if self.rid is None:
        fresh_ids = []
        for _ in range(num):
            fresh_ids.append(uuid.uuid4().hex)
        self.rid = fresh_ids
        return

    if not isinstance(self.rid, list):
        raise ValueError("The rid should be a list for batch processing.")
def _normalize_logprob_params(self, num):
"""Normalize logprob-related parameters for batch processing."""
# Helper function to normalize a parameter
def normalize_param(param, default_value, param_name):
if param is None:
return [default_value] * num
elif not isinstance(param, list):
return [param] * num
else:
assert self.parallel_sample_num == 1
if self.parallel_sample_num > 1:
raise ValueError(
f"Cannot use list {param_name} with parallel_sample_num > 1"
)
return param
if self.logprob_start_len is None:
self.logprob_start_len = [-1] * num
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num
else:
assert self.parallel_sample_num == 1
# Normalize each logprob parameter
self.return_logprob = normalize_param(
self.return_logprob, False, "return_logprob"
)
self.logprob_start_len = normalize_param(
self.logprob_start_len, -1, "logprob_start_len"
)
self.top_logprobs_num = normalize_param(
self.top_logprobs_num, 0, "top_logprobs_num"
)
if self.top_logprobs_num is None:
self.top_logprobs_num = [0] * num
elif not isinstance(self.top_logprobs_num, list):
self.top_logprobs_num = [self.top_logprobs_num] * num
else:
assert self.parallel_sample_num == 1
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = [None] * num
elif not isinstance(self.token_ids_logprob, list):
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
elif not isinstance(self.token_ids_logprob[0], list):
self.token_ids_logprob = [
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
]
else:
assert self.parallel_sample_num == 1
if self.custom_logit_processor is None:
self.custom_logit_processor = [None] * num
elif not isinstance(self.custom_logit_processor, list):
self.custom_logit_processor = [self.custom_logit_processor] * num
else:
assert self.parallel_sample_num == 1
# Other checks
if self.session_params is not None:
assert isinstance(self.session_params, dict) or isinstance(
self.session_params[0], dict
# Handle token_ids_logprob specially due to its nested structure
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = [None] * num
elif not isinstance(self.token_ids_logprob, list):
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
elif not isinstance(self.token_ids_logprob[0], list):
self.token_ids_logprob = [
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
]
elif self.parallel_sample_num > 1:
raise ValueError(
"Cannot use list token_ids_logprob with parallel_sample_num > 1"
)
def _normalize_custom_logit_processor(self, num):
    """Turn ``custom_logit_processor`` into a per-request list for a batch.

    A non-list value, including None, is broadcast to ``num`` entries.
    A list is kept unchanged, but is rejected when parallel sampling is on.

    Raises:
        ValueError: If a list is supplied while ``parallel_sample_num > 1``.
    """
    processor = self.custom_logit_processor
    if isinstance(processor, list):
        if self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )
        return

    # Covers both "no processor" (None) and a single shared processor.
    self.custom_logit_processor = [processor] * num
def _validate_session_params(self):
    """Check that ``session_params`` is None, a dict, or a sequence of dicts.

    Only the first element of a non-dict value is inspected. An empty
    sequence or a value that cannot be indexed is rejected with the same
    ``ValueError`` as a wrongly typed element, rather than leaking an
    ``IndexError`` or ``TypeError`` from the element lookup.

    Raises:
        ValueError: If ``session_params`` is set and is neither a dict nor
            an indexable value whose first element is a dict.
    """
    params = self.session_params
    if params is None or isinstance(params, dict):
        return

    try:
        first = params[0]
    except (IndexError, KeyError, TypeError):
        # Empty sequence or a value that does not support indexing.
        first = None

    if not isinstance(first, dict):
        raise ValueError("Session params must be a dict or a list of dicts.")
def regenerate_rid(self):
    """Replace ``rid`` with a fresh random hex id and return the new id."""
    new_rid = uuid.uuid4().hex
    self.rid = new_rid
    return new_rid
@@ -305,8 +440,15 @@ class TokenizedGenerateReqInput:
class EmbeddingReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
image_data: Optional[Union[List[str], str]] = None
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The request id.