[feat&refactor] Enhance multimodal input support with refactor io_struct (#4938)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from PIL.Image import Image
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -135,9 +136,19 @@ class Engine:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
@@ -190,9 +201,19 @@ class Engine:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
@@ -228,7 +249,13 @@ class Engine:
|
||||
def encode(
|
||||
self,
|
||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL.Image import Image
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||
|
||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||
@@ -56,9 +57,19 @@ class VerlEngine:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None,
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
|
||||
@@ -20,7 +20,13 @@ import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# handle serialization of Image for pydantic
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
else:
|
||||
Image = Any
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -42,10 +48,16 @@ class GenerateReqInput:
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
||||
] = None
|
||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||
audio_data: Optional[Union[List[str], str]] = None
|
||||
# The sampling_params. See descriptions below.
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
@@ -84,6 +96,31 @@ class GenerateReqInput:
|
||||
return_hidden_states: bool = False
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
"""
|
||||
Normalize the batch size and arguments for the request.
|
||||
|
||||
This method resolves various input formats and ensures all parameters
|
||||
are properly formatted as either single values or batches depending on the input.
|
||||
It also handles parallel sampling expansion and sets default values for
|
||||
unspecified parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If inputs are not properly specified (e.g., none or all of
|
||||
text, input_ids, input_embeds are provided)
|
||||
"""
|
||||
self._validate_inputs()
|
||||
self._determine_batch_size()
|
||||
self._handle_parallel_sampling()
|
||||
|
||||
if self.is_single:
|
||||
self._normalize_single_inputs()
|
||||
else:
|
||||
self._normalize_batch_inputs()
|
||||
|
||||
self._validate_session_params()
|
||||
|
||||
def _validate_inputs(self):
|
||||
"""Validate that the input configuration is valid."""
|
||||
if (
|
||||
self.text is None and self.input_ids is None and self.input_embeds is None
|
||||
) or (
|
||||
@@ -95,7 +132,8 @@ class GenerateReqInput:
|
||||
"Either text, input_ids or input_embeds should be provided."
|
||||
)
|
||||
|
||||
# Derive the batch size
|
||||
def _determine_batch_size(self):
|
||||
"""Determine if this is a single example or a batch and the batch size."""
|
||||
if self.text is not None:
|
||||
if isinstance(self.text, str):
|
||||
self.is_single = True
|
||||
@@ -119,21 +157,25 @@ class GenerateReqInput:
|
||||
self.is_single = True
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.is_single = False
|
||||
self.batch_size = len(self.input_embeds)
|
||||
|
||||
# Handle parallel sampling
|
||||
# When parallel sampling is used, we always treat the input as a batch.
|
||||
def _handle_parallel_sampling(self):
|
||||
"""Handle parallel sampling parameters and adjust batch size if needed."""
|
||||
# Determine parallel sample count
|
||||
if self.sampling_params is None:
|
||||
self.parallel_sample_num = 1
|
||||
elif isinstance(self.sampling_params, dict):
|
||||
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
||||
else: # isinstance(self.sampling_params, list):
|
||||
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
|
||||
assert all(
|
||||
self.parallel_sample_num == sampling_params.get("n", 1)
|
||||
for sampling_params in self.sampling_params
|
||||
), "The parallel_sample_num should be the same for all samples in sample params."
|
||||
for sampling_params in self.sampling_params:
|
||||
if self.parallel_sample_num != sampling_params.get("n", 1):
|
||||
raise ValueError(
|
||||
"The parallel_sample_num should be the same for all samples in sample params."
|
||||
)
|
||||
|
||||
# If using parallel sampling with a single example, convert to batch
|
||||
if self.parallel_sample_num > 1 and self.is_single:
|
||||
self.is_single = False
|
||||
if self.text is not None:
|
||||
@@ -141,97 +183,190 @@ class GenerateReqInput:
|
||||
if self.input_ids is not None:
|
||||
self.input_ids = [self.input_ids]
|
||||
|
||||
# Fill in default arguments
|
||||
if self.is_single:
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = {}
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = False
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = -1
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = None
|
||||
def _normalize_single_inputs(self):
|
||||
"""Normalize inputs for a single example."""
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = {}
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = False
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = -1
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = None
|
||||
|
||||
def _normalize_batch_inputs(self):
|
||||
"""Normalize inputs for a batch of examples, including parallel sampling expansion."""
|
||||
# Calculate expanded batch size
|
||||
if self.parallel_sample_num == 1:
|
||||
num = self.batch_size
|
||||
else:
|
||||
if self.parallel_sample_num == 1:
|
||||
num = self.batch_size
|
||||
# Expand parallel_sample_num
|
||||
num = self.batch_size * self.parallel_sample_num
|
||||
|
||||
# Expand input based on type
|
||||
self._expand_inputs(num)
|
||||
self._normalize_lora_paths(num)
|
||||
self._normalize_image_data(num)
|
||||
self._normalize_audio_data(num)
|
||||
self._normalize_sampling_params(num)
|
||||
self._normalize_rid(num)
|
||||
self._normalize_logprob_params(num)
|
||||
self._normalize_custom_logit_processor(num)
|
||||
|
||||
def _expand_inputs(self, num):
|
||||
"""Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
|
||||
if self.text is not None:
|
||||
if not isinstance(self.text, list):
|
||||
raise ValueError("Text should be a list for batch processing.")
|
||||
self.text = self.text * self.parallel_sample_num
|
||||
elif self.input_ids is not None:
|
||||
if not isinstance(self.input_ids, list) or not isinstance(
|
||||
self.input_ids[0], list
|
||||
):
|
||||
raise ValueError(
|
||||
"input_ids should be a list of lists for batch processing."
|
||||
)
|
||||
self.input_ids = self.input_ids * self.parallel_sample_num
|
||||
elif self.input_embeds is not None:
|
||||
if not isinstance(self.input_embeds, list):
|
||||
raise ValueError("input_embeds should be a list for batch processing.")
|
||||
self.input_embeds = self.input_embeds * self.parallel_sample_num
|
||||
|
||||
def _normalize_lora_paths(self, num):
|
||||
"""Normalize LoRA paths for batch processing."""
|
||||
if self.lora_path is not None:
|
||||
if isinstance(self.lora_path, str):
|
||||
self.lora_path = [self.lora_path] * num
|
||||
elif isinstance(self.lora_path, list):
|
||||
self.lora_path = self.lora_path * self.parallel_sample_num
|
||||
else:
|
||||
raise ValueError("lora_path should be a list or a string.")
|
||||
|
||||
def _normalize_image_data(self, num):
|
||||
"""Normalize image data for batch processing."""
|
||||
if self.image_data is None:
|
||||
self.image_data = [None] * num
|
||||
elif not isinstance(self.image_data, list):
|
||||
# Single image, convert to list of single-image lists
|
||||
self.image_data = [[self.image_data]] * num
|
||||
self.modalities = ["image"] * num
|
||||
elif isinstance(self.image_data, list):
|
||||
if len(self.image_data) != self.batch_size:
|
||||
raise ValueError(
|
||||
"The length of image_data should be equal to the batch size."
|
||||
)
|
||||
|
||||
self.modalities = []
|
||||
if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
|
||||
# Already a list of lists, keep as is
|
||||
for i in range(len(self.image_data)):
|
||||
if self.image_data[i] is None or self.image_data[i] == [None]:
|
||||
self.modalities.append(None)
|
||||
elif len(self.image_data[i]) == 1:
|
||||
self.modalities.append("image")
|
||||
elif len(self.image_data[i]) > 1:
|
||||
self.modalities.append("multi-images")
|
||||
# Expand parallel_sample_num
|
||||
num = self.batch_size * self.parallel_sample_num
|
||||
|
||||
if not self.image_data:
|
||||
self.image_data = [None] * num
|
||||
elif not isinstance(self.image_data, list):
|
||||
self.image_data = [self.image_data] * num
|
||||
elif isinstance(self.image_data, list):
|
||||
pass
|
||||
|
||||
if self.audio_data is None:
|
||||
self.audio_data = [None] * num
|
||||
elif not isinstance(self.audio_data, list):
|
||||
self.audio_data = [self.audio_data] * num
|
||||
elif isinstance(self.audio_data, list):
|
||||
pass
|
||||
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = [{}] * num
|
||||
elif not isinstance(self.sampling_params, list):
|
||||
self.sampling_params = [self.sampling_params] * num
|
||||
|
||||
if self.rid is None:
|
||||
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
||||
self.image_data = self.image_data * self.parallel_sample_num
|
||||
self.modalities = self.modalities * self.parallel_sample_num
|
||||
else:
|
||||
assert isinstance(self.rid, list), "The rid should be a list."
|
||||
# List of images for a batch, wrap each in a list
|
||||
wrapped_images = [[img] for img in self.image_data]
|
||||
# Expand for parallel sampling
|
||||
self.image_data = wrapped_images * self.parallel_sample_num
|
||||
self.modalities = ["image"] * num
|
||||
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = [False] * num
|
||||
elif not isinstance(self.return_logprob, list):
|
||||
self.return_logprob = [self.return_logprob] * num
|
||||
def _normalize_audio_data(self, num):
|
||||
"""Normalize audio data for batch processing."""
|
||||
if self.audio_data is None:
|
||||
self.audio_data = [None] * num
|
||||
elif not isinstance(self.audio_data, list):
|
||||
self.audio_data = [self.audio_data] * num
|
||||
elif isinstance(self.audio_data, list):
|
||||
self.audio_data = self.audio_data * self.parallel_sample_num
|
||||
|
||||
def _normalize_sampling_params(self, num):
|
||||
"""Normalize sampling parameters for batch processing."""
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = [{}] * num
|
||||
elif isinstance(self.sampling_params, dict):
|
||||
self.sampling_params = [self.sampling_params] * num
|
||||
else: # Already a list
|
||||
self.sampling_params = self.sampling_params * self.parallel_sample_num
|
||||
|
||||
def _normalize_rid(self, num):
|
||||
"""Normalize request IDs for batch processing."""
|
||||
if self.rid is None:
|
||||
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
||||
elif not isinstance(self.rid, list):
|
||||
raise ValueError("The rid should be a list for batch processing.")
|
||||
|
||||
def _normalize_logprob_params(self, num):
|
||||
"""Normalize logprob-related parameters for batch processing."""
|
||||
|
||||
# Helper function to normalize a parameter
|
||||
def normalize_param(param, default_value, param_name):
|
||||
if param is None:
|
||||
return [default_value] * num
|
||||
elif not isinstance(param, list):
|
||||
return [param] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
if self.parallel_sample_num > 1:
|
||||
raise ValueError(
|
||||
f"Cannot use list {param_name} with parallel_sample_num > 1"
|
||||
)
|
||||
return param
|
||||
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = [-1] * num
|
||||
elif not isinstance(self.logprob_start_len, list):
|
||||
self.logprob_start_len = [self.logprob_start_len] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
# Normalize each logprob parameter
|
||||
self.return_logprob = normalize_param(
|
||||
self.return_logprob, False, "return_logprob"
|
||||
)
|
||||
self.logprob_start_len = normalize_param(
|
||||
self.logprob_start_len, -1, "logprob_start_len"
|
||||
)
|
||||
self.top_logprobs_num = normalize_param(
|
||||
self.top_logprobs_num, 0, "top_logprobs_num"
|
||||
)
|
||||
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = [0] * num
|
||||
elif not isinstance(self.top_logprobs_num, list):
|
||||
self.top_logprobs_num = [self.top_logprobs_num] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = [None] * num
|
||||
elif not isinstance(self.token_ids_logprob, list):
|
||||
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
||||
elif not isinstance(self.token_ids_logprob[0], list):
|
||||
self.token_ids_logprob = [
|
||||
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
||||
]
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if self.custom_logit_processor is None:
|
||||
self.custom_logit_processor = [None] * num
|
||||
elif not isinstance(self.custom_logit_processor, list):
|
||||
self.custom_logit_processor = [self.custom_logit_processor] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
# Other checks
|
||||
if self.session_params is not None:
|
||||
assert isinstance(self.session_params, dict) or isinstance(
|
||||
self.session_params[0], dict
|
||||
# Handle token_ids_logprob specially due to its nested structure
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = [None] * num
|
||||
elif not isinstance(self.token_ids_logprob, list):
|
||||
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
||||
elif not isinstance(self.token_ids_logprob[0], list):
|
||||
self.token_ids_logprob = [
|
||||
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
||||
]
|
||||
elif self.parallel_sample_num > 1:
|
||||
raise ValueError(
|
||||
"Cannot use list token_ids_logprob with parallel_sample_num > 1"
|
||||
)
|
||||
|
||||
def _normalize_custom_logit_processor(self, num):
|
||||
"""Normalize custom logit processor for batch processing."""
|
||||
if self.custom_logit_processor is None:
|
||||
self.custom_logit_processor = [None] * num
|
||||
elif not isinstance(self.custom_logit_processor, list):
|
||||
self.custom_logit_processor = [self.custom_logit_processor] * num
|
||||
elif self.parallel_sample_num > 1:
|
||||
raise ValueError(
|
||||
"Cannot use list custom_logit_processor with parallel_sample_num > 1"
|
||||
)
|
||||
|
||||
def _validate_session_params(self):
|
||||
"""Validate that session parameters are properly formatted."""
|
||||
if self.session_params is not None:
|
||||
if not isinstance(self.session_params, dict) and not isinstance(
|
||||
self.session_params[0], dict
|
||||
):
|
||||
raise ValueError("Session params must be a dict or a list of dicts.")
|
||||
|
||||
def regenerate_rid(self):
|
||||
"""Generate a new request ID and return it."""
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
|
||||
@@ -305,8 +440,15 @@ class TokenizedGenerateReqInput:
|
||||
class EmbeddingReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[str], str]] = None
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
# - List of images (one per request in a batch)
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
||||
] = None
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The request id.
|
||||
|
||||
527
test/srt/test_io_struct.py
Normal file
527
test/srt/test_io_struct.py
Normal file
@@ -0,0 +1,527 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateReqInputNormalization(CustomTestCase):
|
||||
"""Test the normalization of GenerateReqInput for batch processing and different input formats."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for all tests
|
||||
self.base_req = GenerateReqInput(
|
||||
text=["Hello", "World"],
|
||||
sampling_params=[{}, {}],
|
||||
rid=["id1", "id2"],
|
||||
)
|
||||
|
||||
def test_single_image_to_list_of_lists(self):
|
||||
"""Test that a single image is converted to a list of single-image lists."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.image_data = "single_image.jpg" # A single image (non-list)
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be converted to [[image], [image]]
|
||||
self.assertEqual(len(req.image_data), 2)
|
||||
self.assertEqual(len(req.image_data[0]), 1)
|
||||
self.assertEqual(len(req.image_data[1]), 1)
|
||||
self.assertEqual(req.image_data[0][0], "single_image.jpg")
|
||||
self.assertEqual(req.image_data[1][0], "single_image.jpg")
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, ["image", "image"])
|
||||
|
||||
def test_list_of_images_to_list_of_lists(self):
|
||||
"""Test that a list of images is converted to a list of single-image lists."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.image_data = ["image1.jpg", "image2.jpg"] # List of images
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be converted to [[image1], [image2]]
|
||||
self.assertEqual(len(req.image_data), 2)
|
||||
self.assertEqual(len(req.image_data[0]), 1)
|
||||
self.assertEqual(len(req.image_data[1]), 1)
|
||||
self.assertEqual(req.image_data[0][0], "image1.jpg")
|
||||
self.assertEqual(req.image_data[1][0], "image2.jpg")
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, ["image", "image"])
|
||||
|
||||
def test_list_of_lists_with_different_modalities(self):
|
||||
"""Test handling of list of lists of images with different modalities."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.image_data = [
|
||||
["image1.jpg"], # Single image (image modality)
|
||||
["image2.jpg", "image3.jpg"], # Multiple images (multi-images modality)
|
||||
]
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Structure should remain the same
|
||||
self.assertEqual(len(req.image_data), 2)
|
||||
self.assertEqual(len(req.image_data[0]), 1)
|
||||
self.assertEqual(len(req.image_data[1]), 2)
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, ["image", "multi-images"])
|
||||
|
||||
def test_list_of_lists_with_none_values(self):
|
||||
"""Test handling of list of lists with None values."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.image_data = [
|
||||
[None], # None value
|
||||
["image.jpg"], # Single image
|
||||
]
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Structure should remain the same
|
||||
self.assertEqual(len(req.image_data), 2)
|
||||
self.assertEqual(len(req.image_data[0]), 1)
|
||||
self.assertEqual(len(req.image_data[1]), 1)
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, [None, "image"])
|
||||
|
||||
def test_expanding_parallel_sample_correlation(self):
|
||||
"""Test that when expanding with parallel samples, prompts, images and modalities are properly correlated."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.text = ["Prompt 1", "Prompt 2"]
|
||||
req.image_data = [
|
||||
["image1.jpg"],
|
||||
["image2.jpg", "image3.jpg"],
|
||||
]
|
||||
req.sampling_params = {"n": 3} # All prompts get 3 samples
|
||||
|
||||
# Define expected values before normalization
|
||||
expected_text = req.text * 3
|
||||
expected_images = req.image_data * 3
|
||||
expected_modalities = ["image", "multi-images"] * 3
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be expanded to 6 items (2 original * 3 parallel)
|
||||
self.assertEqual(len(req.image_data), 6)
|
||||
|
||||
# Check that images are properly expanded
|
||||
self.assertEqual(req.image_data, expected_images)
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, expected_modalities)
|
||||
|
||||
# Ensure that text items are properly duplicated too
|
||||
self.assertEqual(req.text, expected_text)
|
||||
|
||||
def test_list_of_lists_with_none_values(self):
|
||||
"""Test handling of list of lists with None values."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.image_data = [
|
||||
[None], # None value
|
||||
["image.jpg"], # Single image
|
||||
]
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Structure should remain the same
|
||||
self.assertEqual(len(req.image_data), 2)
|
||||
self.assertEqual(len(req.image_data[0]), 1)
|
||||
self.assertEqual(len(req.image_data[1]), 1)
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, [None, "image"])
|
||||
|
||||
def test_specific_parallel_n_per_sample(self):
|
||||
"""Test parallel expansion when different samples have different n values."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.text = ["Prompt 1", "Prompt 2"]
|
||||
req.image_data = [
|
||||
["image1.jpg"],
|
||||
["image2.jpg", "image3.jpg"],
|
||||
]
|
||||
req.sampling_params = [
|
||||
{"n": 2},
|
||||
{"n": 2},
|
||||
] # First prompt gets 2 samples, second prompt gets 2 samples
|
||||
|
||||
expected_images = req.image_data * 2
|
||||
expected_modalities = ["image", "multi-images"] * 2
|
||||
expected_text = req.text * 2
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be expanded to 4 items (2 original * 2 parallel)
|
||||
self.assertEqual(len(req.image_data), 4)
|
||||
|
||||
# Check that the first 2 are copies for the first prompt
|
||||
self.assertEqual(req.image_data, expected_images)
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, expected_modalities)
|
||||
|
||||
# Check text expansion
|
||||
self.assertEqual(req.text, expected_text)
|
||||
|
||||
def test_mixed_none_and_images_with_parallel_samples(self):
|
||||
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
||||
req.image_data = [
|
||||
["image1.jpg"],
|
||||
None,
|
||||
["image3_1.jpg", "image3_2.jpg"],
|
||||
]
|
||||
req.sampling_params = {"n": 2} # All prompts get 2 samples
|
||||
|
||||
expected_images = req.image_data * 2
|
||||
expected_modalities = ["image", None, "multi-images"] * 2
|
||||
expected_text = req.text * 2
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be expanded to 6 items (3 original * 2 parallel)
|
||||
self.assertEqual(len(req.image_data), 6)
|
||||
|
||||
# Check image data
|
||||
self.assertEqual(req.image_data, expected_images)
|
||||
|
||||
# Check modalities
|
||||
self.assertEqual(req.modalities, expected_modalities)
|
||||
|
||||
# Check text expansion
|
||||
self.assertEqual(req.text, expected_text)
|
||||
|
||||
def test_correlation_with_sampling_params(self):
|
||||
"""Test that sampling parameters are correctly correlated with prompts during expansion."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.text = ["Prompt 1", "Prompt 2"]
|
||||
req.image_data = [
|
||||
["image1.jpg"],
|
||||
["image2.jpg"],
|
||||
]
|
||||
req.sampling_params = [
|
||||
{"temperature": 0.7, "n": 2},
|
||||
{"temperature": 0.9, "n": 2},
|
||||
]
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Check sampling params expansion
|
||||
self.assertEqual(len(req.sampling_params), 4)
|
||||
self.assertEqual(req.sampling_params[0]["temperature"], 0.7)
|
||||
self.assertEqual(req.sampling_params[1]["temperature"], 0.9)
|
||||
self.assertEqual(req.sampling_params[2]["temperature"], 0.7)
|
||||
self.assertEqual(req.sampling_params[3]["temperature"], 0.9)
|
||||
|
||||
# Should be expanded to 4 items (2 original * 2 parallel)
|
||||
self.assertEqual(len(req.image_data), 4)
|
||||
|
||||
# Check correlation with images
|
||||
self.assertEqual(req.image_data[0], ["image1.jpg"])
|
||||
self.assertEqual(req.image_data[1], ["image2.jpg"])
|
||||
self.assertEqual(req.image_data[2], ["image1.jpg"])
|
||||
self.assertEqual(req.image_data[3], ["image2.jpg"])
|
||||
|
||||
def test_single_example_with_image(self):
|
||||
"""Test handling of single example with image."""
|
||||
req = GenerateReqInput(
|
||||
text="Hello",
|
||||
image_data="single_image.jpg",
|
||||
)
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# For single examples, image_data doesn't get processed into lists
|
||||
self.assertEqual(req.image_data, "single_image.jpg")
|
||||
self.assertIsNone(req.modalities) # Modalities isn't set for single examples
|
||||
|
||||
def test_single_to_batch_with_parallel_sampling(self):
|
||||
"""Test single example converted to batch with parallel sampling."""
|
||||
req = GenerateReqInput(
|
||||
text="Hello",
|
||||
image_data="single_image.jpg",
|
||||
sampling_params={"n": 3}, # parallel_sample_num = 3
|
||||
)
|
||||
|
||||
# Define expected values before normalization
|
||||
expected_text = ["Hello"] * 3
|
||||
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be converted to batch with text=["Hello"]
|
||||
self.assertEqual(req.text, expected_text)
|
||||
|
||||
# Image should be automatically wrapped to list of lists with length 1*3=3
|
||||
self.assertEqual(len(req.image_data), 3)
|
||||
self.assertEqual(req.image_data[0][0], "single_image.jpg")
|
||||
self.assertEqual(req.image_data[1][0], "single_image.jpg")
|
||||
self.assertEqual(req.image_data[2][0], "single_image.jpg")
|
||||
|
||||
# Modalities should be set for all 3 examples
|
||||
self.assertEqual(req.modalities, ["image", "image", "image"])
|
||||
|
||||
def test_audio_data_handling(self):
    """audio_data is broadcast for a scalar and preserved for a list."""
    # A scalar audio path is replicated across the batch of 2.
    req = copy.deepcopy(self.base_req)
    req.audio_data = "audio.mp3"
    req.normalize_batch_and_arguments()
    self.assertEqual(len(req.audio_data), 2)
    self.assertEqual(req.audio_data, ["audio.mp3", "audio.mp3"])

    # An explicit per-request list passes through untouched.
    req = copy.deepcopy(self.base_req)
    req.audio_data = ["audio1.mp3", "audio2.mp3"]
    req.normalize_batch_and_arguments()
    self.assertEqual(len(req.audio_data), 2)
    self.assertEqual(req.audio_data, ["audio1.mp3", "audio2.mp3"])
def test_input_ids_normalization(self):
    """Requests given via input_ids normalize the same way text requests do."""
    # A flat token list is treated as a single request.
    req = GenerateReqInput(input_ids=[1, 2, 3])
    req.normalize_batch_and_arguments()
    self.assertTrue(req.is_single)
    self.assertEqual(req.batch_size, 1)

    # A list of token lists is treated as a batch.
    req = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5, 6]])
    req.normalize_batch_and_arguments()
    self.assertFalse(req.is_single)
    self.assertEqual(req.batch_size, 2)

    # Parallel sampling multiplies the batch: 2 requests * n=2 -> 4 entries.
    req = GenerateReqInput(
        input_ids=[[1, 2, 3], [4, 5, 6]], sampling_params={"n": 2}
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(len(req.input_ids), 4)
def test_input_embeds_normalization(self):
    """input_embeds rank decides single vs. batch classification."""
    # One 2-D embedding matrix is a single request.
    req = GenerateReqInput(input_embeds=[[0.1, 0.2], [0.3, 0.4]])
    req.normalize_batch_and_arguments()
    self.assertTrue(req.is_single)
    self.assertEqual(req.batch_size, 1)

    # A list of embedding matrices is a batch of 2.
    req = GenerateReqInput(input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]])
    req.normalize_batch_and_arguments()
    self.assertFalse(req.is_single)
    self.assertEqual(req.batch_size, 2)
def test_lora_path_normalization(self):
    """lora_path is broadcast, preserved, and tiled under parallel sampling."""
    # A scalar path is broadcast to every request in the batch.
    req = GenerateReqInput(text=["Hello", "World"], lora_path="path/to/lora")
    req.normalize_batch_and_arguments()
    self.assertEqual(req.lora_path, ["path/to/lora", "path/to/lora"])

    # A per-request list passes through unchanged.
    req = GenerateReqInput(text=["Hello", "World"], lora_path=["path1", "path2"])
    req.normalize_batch_and_arguments()
    self.assertEqual(req.lora_path, ["path1", "path2"])

    # With n=2 parallel samples the list is repeated once per sample.
    req = GenerateReqInput(
        text=["Hello", "World"],
        lora_path=["path1", "path2"],
        sampling_params={"n": 2},
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.lora_path, ["path1", "path2"] * 2)
def test_logprob_parameters_normalization(self):
    """Logprob options stay scalar for singles and become lists for batches."""
    # Single request: scalar options are left untouched.
    req = GenerateReqInput(
        text="Hello",
        return_logprob=True,
        logprob_start_len=10,
        top_logprobs_num=5,
        token_ids_logprob=[7, 8, 9],
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.return_logprob, True)
    self.assertEqual(req.logprob_start_len, 10)
    self.assertEqual(req.top_logprobs_num, 5)
    self.assertEqual(req.token_ids_logprob, [7, 8, 9])

    # Batch with scalar options: each scalar is broadcast per request.
    req = GenerateReqInput(
        text=["Hello", "World"],
        return_logprob=True,
        logprob_start_len=10,
        top_logprobs_num=5,
        token_ids_logprob=[7, 8, 9],
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.return_logprob, [True, True])
    self.assertEqual(req.logprob_start_len, [10, 10])
    self.assertEqual(req.top_logprobs_num, [5, 5])
    self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [7, 8, 9]])

    # Batch with per-request lists: values pass through unchanged.
    req = GenerateReqInput(
        text=["Hello", "World"],
        return_logprob=[True, False],
        logprob_start_len=[10, 5],
        top_logprobs_num=[5, 3],
        token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.return_logprob, [True, False])
    self.assertEqual(req.logprob_start_len, [10, 5])
    self.assertEqual(req.top_logprobs_num, [5, 3])
    self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
def test_custom_logit_processor_normalization(self):
    """custom_logit_processor is broadcast for scalars, kept for lists."""
    # A single serialized processor is duplicated across the batch.
    req = GenerateReqInput(
        text=["Hello", "World"], custom_logit_processor="serialized_processor"
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(
        req.custom_logit_processor,
        ["serialized_processor", "serialized_processor"],
    )

    # A per-request list of processors passes through unchanged.
    req = GenerateReqInput(
        text=["Hello", "World"], custom_logit_processor=["processor1", "processor2"]
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.custom_logit_processor, ["processor1", "processor2"])
def test_session_params_handling(self):
    """session_params accepts a single dict or a per-request list of dicts."""
    # A single dict applies to the whole batch and is not expanded.
    req = GenerateReqInput(
        text=["Hello", "World"], session_params={"id": "session1", "offset": 10}
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.session_params, {"id": "session1", "offset": 10})

    # A list of dicts is preserved one-to-one.
    req = GenerateReqInput(
        text=["Hello", "World"],
        session_params=[{"id": "session1"}, {"id": "session2"}],
    )
    req.normalize_batch_and_arguments()
    self.assertEqual(req.session_params, [{"id": "session1"}, {"id": "session2"}])
def test_getitem_method(self):
    """Indexing a normalized batch request yields the per-request slice."""
    req = GenerateReqInput(
        text=["Hello", "World"],
        image_data=[["img1.jpg"], ["img2.jpg"]],
        audio_data=["audio1.mp3", "audio2.mp3"],
        sampling_params=[{"temp": 0.7}, {"temp": 0.8}],
        rid=["id1", "id2"],
        return_logprob=[True, False],
        logprob_start_len=[10, 5],
        top_logprobs_num=[5, 3],
        token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
        stream=True,
        log_metrics=True,
        modalities=["image", "image"],
        lora_path=["path1", "path2"],
        custom_logit_processor=["processor1", "processor2"],
        return_hidden_states=True,
    )
    req.normalize_batch_and_arguments()

    first = req[0]
    # Per-request fields take index 0; batch-wide flags carry over as-is.
    expected = {
        "text": "Hello",
        "image_data": ["img1.jpg"],
        "audio_data": "audio1.mp3",
        "sampling_params": {"temp": 0.7},
        "rid": "id1",
        "return_logprob": True,
        "logprob_start_len": 10,
        "top_logprobs_num": 5,
        "token_ids_logprob": [7, 8, 9],
        "stream": True,
        "log_metrics": True,
        "modalities": "image",
        "lora_path": "path1",
        "custom_logit_processor": "processor1",
        "return_hidden_states": True,
    }
    for attr, value in expected.items():
        self.assertEqual(getattr(first, attr), value)
def test_regenerate_rid(self):
    """regenerate_rid mints a fresh id and stores it on the request."""
    req = GenerateReqInput(text="Hello")
    req.normalize_batch_and_arguments()

    old_rid = req.rid
    fresh_rid = req.regenerate_rid()

    # The returned id differs from the old one and is persisted on the request.
    self.assertNotEqual(old_rid, fresh_rid)
    self.assertEqual(req.rid, fresh_rid)
def test_error_cases(self):
    """Invalid input combinations must raise ValueError during normalization."""
    # No input source at all (neither text, input_ids, nor input_embeds).
    with self.assertRaises(ValueError):
        GenerateReqInput().normalize_batch_and_arguments()

    # All three input sources at once is ambiguous and rejected.
    with self.assertRaises(ValueError):
        GenerateReqInput(
            text="Hello", input_ids=[1, 2, 3], input_embeds=[[0.1, 0.2]]
        ).normalize_batch_and_arguments()
def test_multiple_input_formats(self):
    """Each of the three input sources alone yields a single (non-batch) request."""
    single_request_kwargs = (
        {"text": "Hello"},
        {"input_ids": [1, 2, 3]},
        {"input_embeds": [[0.1, 0.2]]},
    )
    for kwargs in single_request_kwargs:
        req = GenerateReqInput(**kwargs)
        req.normalize_batch_and_arguments()
        self.assertTrue(req.is_single)
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user