[feat&refactor] Enhance multimodal input support with refactor io_struct (#4938)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from PIL.Image import Image
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -135,9 +136,19 @@ class Engine:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
@@ -190,9 +201,19 @@ class Engine:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
@@ -228,7 +249,13 @@ class Engine:
|
||||
def encode(
|
||||
self,
|
||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL.Image import Image
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||
|
||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||
@@ -56,9 +57,19 @@ class VerlEngine:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
|
||||
@@ -20,7 +20,13 @@ import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# handle serialization of Image for pydantic
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
else:
|
||||
Image = Any
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -42,10 +48,16 @@ class GenerateReqInput:
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
||||
] = None
|
||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||
audio_data: Optional[Union[List[str], str]] = None
|
||||
# The sampling_params. See descriptions below.
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
@@ -84,6 +96,31 @@ class GenerateReqInput:
|
||||
return_hidden_states: bool = False
|
||||
|
||||
def normalize_batch_and_arguments(self):
    """
    Bring the request into canonical single-example or batch form.

    The stages run in a fixed order: input validation, batch-size
    detection, parallel-sampling handling, per-field normalization (the
    single or the batch variant, selected by ``self.is_single``), and
    finally session-parameter validation.

    Raises:
        ValueError: If inputs are not properly specified (e.g., none or all of
            text, input_ids, input_embeds are provided)
    """
    # These three must run first and in this order: ``is_single`` is only
    # read after the parallel-sampling stage has had a chance to change it.
    for stage in (
        self._validate_inputs,
        self._determine_batch_size,
        self._handle_parallel_sampling,
    ):
        stage()

    normalize_fields = (
        self._normalize_single_inputs
        if self.is_single
        else self._normalize_batch_inputs
    )
    normalize_fields()

    self._validate_session_params()
|
||||
|
||||
def _validate_inputs(self):
|
||||
"""Validate that the input configuration is valid."""
|
||||
if (
|
||||
self.text is None and self.input_ids is None and self.input_embeds is None
|
||||
) or (
|
||||
@@ -95,7 +132,8 @@ class GenerateReqInput:
|
||||
"Either text, input_ids or input_embeds should be provided."
|
||||
)
|
||||
|
||||
# Derive the batch size
|
||||
def _determine_batch_size(self):
|
||||
"""Determine if this is a single example or a batch and the batch size."""
|
||||
if self.text is not None:
|
||||
if isinstance(self.text, str):
|
||||
self.is_single = True
|
||||
@@ -119,21 +157,25 @@ class GenerateReqInput:
|
||||
self.is_single = True
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.is_single = False
|
||||
self.batch_size = len(self.input_embeds)
|
||||
|
||||
# Handle parallel sampling
|
||||
# When parallel sampling is used, we always treat the input as a batch.
|
||||
def _handle_parallel_sampling(self):
|
||||
"""Handle parallel sampling parameters and adjust batch size if needed."""
|
||||
# Determine parallel sample count
|
||||
if self.sampling_params is None:
|
||||
self.parallel_sample_num = 1
|
||||
elif isinstance(self.sampling_params, dict):
|
||||
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
||||
else: # isinstance(self.sampling_params, list):
|
||||
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
|
||||
assert all(
|
||||
self.parallel_sample_num == sampling_params.get("n", 1)
|
||||
for sampling_params in self.sampling_params
|
||||
), "The parallel_sample_num should be the same for all samples in sample params."
|
||||
for sampling_params in self.sampling_params:
|
||||
if self.parallel_sample_num != sampling_params.get("n", 1):
|
||||
raise ValueError(
|
||||
"The parallel_sample_num should be the same for all samples in sample params."
|
||||
)
|
||||
|
||||
# If using parallel sampling with a single example, convert to batch
|
||||
if self.parallel_sample_num > 1 and self.is_single:
|
||||
self.is_single = False
|
||||
if self.text is not None:
|
||||
@@ -141,97 +183,190 @@ class GenerateReqInput:
|
||||
if self.input_ids is not None:
|
||||
self.input_ids = [self.input_ids]
|
||||
|
||||
# Fill in default arguments
|
||||
if self.is_single:
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = {}
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = False
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = -1
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = None
|
||||
def _normalize_single_inputs(self):
    """Fill in defaults for the fields of a single (non-batch) request.

    Fields that are still ``None`` receive their default; an empty or
    missing ``token_ids_logprob`` is collapsed to ``None``.
    """
    # (attribute name, factory producing its default). Factories are used so
    # that the dict and the request id are created only when actually needed.
    fallbacks = (
        ("sampling_params", dict),
        ("rid", lambda: uuid.uuid4().hex),
        ("return_logprob", lambda: False),
        ("logprob_start_len", lambda: -1),
        ("top_logprobs_num", lambda: 0),
    )
    for attr, make_default in fallbacks:
        if getattr(self, attr) is None:
            setattr(self, attr, make_default())

    if not self.token_ids_logprob:  # covers both None and []
        self.token_ids_logprob = None
|
||||
|
||||
def _normalize_batch_inputs(self):
|
||||
"""Normalize inputs for a batch of examples, including parallel sampling expansion."""
|
||||
# Calculate expanded batch size
|
||||
if self.parallel_sample_num == 1:
|
||||
num = self.batch_size
|
||||
else:
|
||||
if self.parallel_sample_num == 1:
|
||||
num = self.batch_size
|
||||
# Expand parallel_sample_num
|
||||
num = self.batch_size * self.parallel_sample_num
|
||||
|
||||
# Expand input based on type
|
||||
self._expand_inputs(num)
|
||||
self._normalize_lora_paths(num)
|
||||
self._normalize_image_data(num)
|
||||
self._normalize_audio_data(num)
|
||||
self._normalize_sampling_params(num)
|
||||
self._normalize_rid(num)
|
||||
self._normalize_logprob_params(num)
|
||||
self._normalize_custom_logit_processor(num)
|
||||
|
||||
def _expand_inputs(self, num):
    """Repeat the primary input once per parallel sample.

    Only the first populated field among ``text``, ``input_ids`` and
    ``input_embeds`` is expanded; the remaining ones are left untouched.
    ``num`` is not used by this method.

    Raises:
        ValueError: If the populated field is not batch-shaped (``text`` and
            ``input_embeds`` must be lists; ``input_ids`` must be a list
            whose first element is a list).
    """
    if self.text is not None:
        if not isinstance(self.text, list):
            raise ValueError("Text should be a list for batch processing.")
        self.text = self.text * self.parallel_sample_num
        return

    if self.input_ids is not None:
        token_batches = self.input_ids
        # The first element is inspected only when the outer value is a list.
        is_nested = isinstance(token_batches, list) and isinstance(
            token_batches[0], list
        )
        if not is_nested:
            raise ValueError(
                "input_ids should be a list of lists for batch processing."
            )
        self.input_ids = token_batches * self.parallel_sample_num
        return

    if self.input_embeds is not None:
        if not isinstance(self.input_embeds, list):
            raise ValueError("input_embeds should be a list for batch processing.")
        self.input_embeds = self.input_embeds * self.parallel_sample_num
|
||||
|
||||
def _normalize_lora_paths(self, num):
    """Turn ``lora_path`` into a per-request list.

    A single string is repeated ``num`` times, a list is repeated once per
    parallel sample, and ``None`` is left as is.

    Raises:
        ValueError: If ``lora_path`` is neither ``None``, a string nor a list.
    """
    paths = self.lora_path
    if paths is None:
        return

    if isinstance(paths, str):
        self.lora_path = [paths] * num
    elif isinstance(paths, list):
        self.lora_path = paths * self.parallel_sample_num
    else:
        raise ValueError("lora_path should be a list or a string.")
|
||||
|
||||
def _normalize_image_data(self, num):
|
||||
"""Normalize image data for batch processing."""
|
||||
if self.image_data is None:
|
||||
self.image_data = [None] * num
|
||||
elif not isinstance(self.image_data, list):
|
||||
# Single image, convert to list of single-image lists
|
||||
self.image_data = [[self.image_data]] * num
|
||||
self.modalities = ["image"] * num
|
||||
elif isinstance(self.image_data, list):
|
||||
if len(self.image_data) != self.batch_size:
|
||||
raise ValueError(
|
||||
"The length of image_data should be equal to the batch size."
|
||||
)
|
||||
|
||||
self.modalities = []
|
||||
if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
|
||||
# Already a list of lists, keep as is
|
||||
for i in range(len(self.image_data)):
|
||||
if self.image_data[i] is None or self.image_data[i] == [None]:
|
||||
self.modalities.append(None)
|
||||
elif len(self.image_data[i]) == 1:
|
||||
self.modalities.append("image")
|
||||
elif len(self.image_data[i]) > 1:
|
||||
self.modalities.append("multi-images")
|
||||
# Expand parallel_sample_num
|
||||
num = self.batch_size * self.parallel_sample_num
|
||||
|
||||
if not self.image_data:
|
||||
self.image_data = [None] * num
|
||||
elif not isinstance(self.image_data, list):
|
||||
self.image_data = [self.image_data] * num
|
||||
elif isinstance(self.image_data, list):
|
||||
pass
|
||||
|
||||
if self.audio_data is None:
|
||||
self.audio_data = [None] * num
|
||||
elif not isinstance(self.audio_data, list):
|
||||
self.audio_data = [self.audio_data] * num
|
||||
elif isinstance(self.audio_data, list):
|
||||
pass
|
||||
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = [{}] * num
|
||||
elif not isinstance(self.sampling_params, list):
|
||||
self.sampling_params = [self.sampling_params] * num
|
||||
|
||||
if self.rid is None:
|
||||
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
||||
self.image_data = self.image_data * self.parallel_sample_num
|
||||
self.modalities = self.modalities * self.parallel_sample_num
|
||||
else:
|
||||
assert isinstance(self.rid, list), "The rid should be a list."
|
||||
# List of images for a batch, wrap each in a list
|
||||
wrapped_images = [[img] for img in self.image_data]
|
||||
# Expand for parallel sampling
|
||||
self.image_data = wrapped_images * self.parallel_sample_num
|
||||
self.modalities = ["image"] * num
|
||||
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = [False] * num
|
||||
elif not isinstance(self.return_logprob, list):
|
||||
self.return_logprob = [self.return_logprob] * num
|
||||
def _normalize_audio_data(self, num):
    """Turn ``audio_data`` into a per-request list.

    ``None`` becomes ``num`` empty slots, a single value is repeated ``num``
    times, and a list is repeated once per parallel sample.
    """
    audio = self.audio_data
    if audio is None:
        expanded = [None] * num
    elif isinstance(audio, list):
        expanded = audio * self.parallel_sample_num
    else:
        expanded = [audio] * num
    self.audio_data = expanded
|
||||
|
||||
def _normalize_sampling_params(self, num):
    """Turn ``sampling_params`` into a per-request list.

    ``None`` becomes ``num`` independent empty dicts, a single dict is
    repeated ``num`` times, and a list is repeated once per parallel sample.
    """
    params = self.sampling_params
    if params is None:
        # Build a fresh dict per request. ``[{}] * num`` would place the
        # same dict object in every slot, so a later in-place change to one
        # request's parameters would silently show up in all of them.
        self.sampling_params = [{} for _ in range(num)]
    elif isinstance(params, dict):
        # A caller-supplied dict is shared deliberately: every request
        # references the one object the caller passed in.
        self.sampling_params = [params] * num
    else:  # Already a list
        self.sampling_params = params * self.parallel_sample_num
|
||||
|
||||
def _normalize_rid(self, num):
    """Ensure a batch request carries a list of request ids.

    When no id was supplied, ``num`` random hex ids are generated. An
    existing list is kept unchanged.

    Raises:
        ValueError: If ``rid`` is set but is not a list.
    """
    existing = self.rid
    if existing is None:
        self.rid = [uuid.uuid4().hex for _ in range(num)]
        return

    if not isinstance(existing, list):
        raise ValueError("The rid should be a list for batch processing.")
|
||||
|
||||
def _normalize_logprob_params(self, num):
|
||||
"""Normalize logprob-related parameters for batch processing."""
|
||||
|
||||
# Helper function to normalize a parameter
|
||||
def normalize_param(param, default_value, param_name):
|
||||
if param is None:
|
||||
return [default_value] * num
|
||||
elif not isinstance(param, list):
|
||||
return [param] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
if self.parallel_sample_num > 1:
|
||||
raise ValueError(
|
||||
f"Cannot use list {param_name} with parallel_sample_num > 1"
|
||||
)
|
||||
return param
|
||||
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = [-1] * num
|
||||
elif not isinstance(self.logprob_start_len, list):
|
||||
self.logprob_start_len = [self.logprob_start_len] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
# Normalize each logprob parameter
|
||||
self.return_logprob = normalize_param(
|
||||
self.return_logprob, False, "return_logprob"
|
||||
)
|
||||
self.logprob_start_len = normalize_param(
|
||||
self.logprob_start_len, -1, "logprob_start_len"
|
||||
)
|
||||
self.top_logprobs_num = normalize_param(
|
||||
self.top_logprobs_num, 0, "top_logprobs_num"
|
||||
)
|
||||
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = [0] * num
|
||||
elif not isinstance(self.top_logprobs_num, list):
|
||||
self.top_logprobs_num = [self.top_logprobs_num] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = [None] * num
|
||||
elif not isinstance(self.token_ids_logprob, list):
|
||||
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
||||
elif not isinstance(self.token_ids_logprob[0], list):
|
||||
self.token_ids_logprob = [
|
||||
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
||||
]
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if self.custom_logit_processor is None:
|
||||
self.custom_logit_processor = [None] * num
|
||||
elif not isinstance(self.custom_logit_processor, list):
|
||||
self.custom_logit_processor = [self.custom_logit_processor] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
# Other checks
|
||||
if self.session_params is not None:
|
||||
assert isinstance(self.session_params, dict) or isinstance(
|
||||
self.session_params[0], dict
|
||||
# Handle token_ids_logprob specially due to its nested structure
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = [None] * num
|
||||
elif not isinstance(self.token_ids_logprob, list):
|
||||
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
||||
elif not isinstance(self.token_ids_logprob[0], list):
|
||||
self.token_ids_logprob = [
|
||||
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
||||
]
|
||||
elif self.parallel_sample_num > 1:
|
||||
raise ValueError(
|
||||
"Cannot use list token_ids_logprob with parallel_sample_num > 1"
|
||||
)
|
||||
|
||||
def _normalize_custom_logit_processor(self, num):
    """Turn ``custom_logit_processor`` into a per-request list.

    ``None`` or a single value is repeated ``num`` times. A list is kept
    unchanged, but is rejected when parallel sampling is active.

    Raises:
        ValueError: If a list is supplied while ``parallel_sample_num > 1``.
    """
    processor = self.custom_logit_processor
    if isinstance(processor, list):
        if self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )
        return

    # Covers both the unset case (a list of None) and a single processor.
    self.custom_logit_processor = [processor] * num
|
||||
|
||||
def _validate_session_params(self):
    """Validate that session parameters are properly formatted.

    ``session_params`` may be ``None``, a dict, or an indexable sequence
    whose first element is a dict. Only the first element is inspected.

    Raises:
        ValueError: If ``session_params`` has any other shape, including an
            empty sequence or a value that cannot be indexed.
    """
    params = self.session_params
    if params is None or isinstance(params, dict):
        return

    # Probe the first element defensively: an empty list or a non-indexable
    # value must be reported as the documented ValueError rather than
    # leaking an IndexError/TypeError from the subscript.
    try:
        first = params[0]
    except (IndexError, KeyError, TypeError):
        first = None

    if not isinstance(first, dict):
        raise ValueError("Session params must be a dict or a list of dicts.")
|
||||
|
||||
def regenerate_rid(self):
    """Replace the request id with a newly generated one and return it."""
    fresh_id = uuid.uuid4().hex
    self.rid = fresh_id
    return fresh_id
|
||||
|
||||
@@ -305,8 +440,15 @@ class TokenizedGenerateReqInput:
|
||||
class EmbeddingReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[str], str]] = None
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
||||
] = None
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The request id.
|
||||
|
||||
Reference in New Issue
Block a user