From d09a51f1f62038ec25225012dec6bb7488eb8054 Mon Sep 17 00:00:00 2001 From: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com> Date: Tue, 8 Apr 2025 14:48:07 -0700 Subject: [PATCH] [feat&refactor] Enhance multimodal input support with refactor io_struct (#4938) Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/engine.py | 41 +- python/sglang/srt/entrypoints/verl_engine.py | 17 +- python/sglang/srt/managers/io_struct.py | 330 ++++++++---- test/srt/test_io_struct.py | 527 +++++++++++++++++++ 4 files changed, 811 insertions(+), 104 deletions(-) create mode 100644 test/srt/test_io_struct.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b92c6ecdb..33aab232f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union import zmq import zmq.asyncio +from PIL.Image import Image # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -135,9 +136,19 @@ class Engine: sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be a file name, a url, or base64 encoded string. - # See also python/sglang/srt/utils.py:load_image. - image_data: Optional[Union[List[str], str]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. 
+ image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -190,9 +201,19 @@ class Engine: sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be a file name, a url, or base64 encoded string. - # See also python/sglang/srt/utils.py:load_image. - image_data: Optional[Union[List[str], str]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -228,7 +249,13 @@ class Engine: def encode( self, prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - image_data: Optional[Union[List[str], str]] = None, + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, ) -> Dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. 
diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index 13f60451e..fa7177153 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist +from PIL.Image import Image from torch.distributed.tensor import DeviceMesh, DTensor from sglang.srt.model_executor.model_runner import LocalSerializedTensor @@ -56,9 +57,19 @@ class VerlEngine: sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be a file name, a url, or base64 encoded string. - # See also python/sglang/srt/utils.py:load_image. - image_data: Optional[Union[List[str], str]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. 
+ image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e25a8f242..591b0660f 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -20,7 +20,13 @@ import copy import uuid from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + +# handle serialization of Image for pydantic +if TYPE_CHECKING: + from PIL.Image import Image +else: + Image = Any from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -42,10 +48,16 @@ class GenerateReqInput: input_ids: Optional[Union[List[List[int]], List[int]]] = None # The embeddings for input_ids; one can specify either text or input_ids or input_embeds. input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None - # The image input. It can be a file name, a url, or base64 encoded string. - # See also python/sglang/srt/utils.py:load_image. - image_data: Optional[Union[List[str], str]] = None - # The audio input. Like image data, tt can be a file name, a url, or base64 encoded string. + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. 
+ image_data: Optional[ + Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]] + ] = None + # The audio input. Like image data, it can be a file name, a url, or base64 encoded string. audio_data: Optional[Union[List[str], str]] = None # The sampling_params. See descriptions below. sampling_params: Optional[Union[List[Dict], Dict]] = None @@ -84,6 +96,31 @@ class GenerateReqInput: return_hidden_states: bool = False def normalize_batch_and_arguments(self): + """ + Normalize the batch size and arguments for the request. + + This method resolves various input formats and ensures all parameters + are properly formatted as either single values or batches depending on the input. + It also handles parallel sampling expansion and sets default values for + unspecified parameters. + + Raises: + ValueError: If inputs are not properly specified (e.g., none or all of + text, input_ids, input_embeds are provided) + """ + self._validate_inputs() + self._determine_batch_size() + self._handle_parallel_sampling() + + if self.is_single: + self._normalize_single_inputs() + else: + self._normalize_batch_inputs() + + self._validate_session_params() + + def _validate_inputs(self): + """Validate that the input configuration is valid.""" if ( self.text is None and self.input_ids is None and self.input_embeds is None ) or ( @@ -95,7 +132,8 @@ class GenerateReqInput: "Either text, input_ids or input_embeds should be provided." ) - # Derive the batch size + def _determine_batch_size(self): + """Determine if this is a single example or a batch and the batch size.""" if self.text is not None: if isinstance(self.text, str): self.is_single = True @@ -119,21 +157,25 @@ class GenerateReqInput: self.is_single = True self.batch_size = 1 else: + self.is_single = False self.batch_size = len(self.input_embeds) - # Handle parallel sampling - # When parallel sampling is used, we always treat the input as a batch. 
+ def _handle_parallel_sampling(self): + """Handle parallel sampling parameters and adjust batch size if needed.""" + # Determine parallel sample count if self.sampling_params is None: self.parallel_sample_num = 1 elif isinstance(self.sampling_params, dict): self.parallel_sample_num = self.sampling_params.get("n", 1) else: # isinstance(self.sampling_params, list): self.parallel_sample_num = self.sampling_params[0].get("n", 1) - assert all( - self.parallel_sample_num == sampling_params.get("n", 1) - for sampling_params in self.sampling_params - ), "The parallel_sample_num should be the same for all samples in sample params." + for sampling_params in self.sampling_params: + if self.parallel_sample_num != sampling_params.get("n", 1): + raise ValueError( + "The parallel_sample_num should be the same for all samples in sample params." + ) + # If using parallel sampling with a single example, convert to batch if self.parallel_sample_num > 1 and self.is_single: self.is_single = False if self.text is not None: @@ -141,97 +183,190 @@ class GenerateReqInput: if self.input_ids is not None: self.input_ids = [self.input_ids] - # Fill in default arguments - if self.is_single: - if self.sampling_params is None: - self.sampling_params = {} - if self.rid is None: - self.rid = uuid.uuid4().hex - if self.return_logprob is None: - self.return_logprob = False - if self.logprob_start_len is None: - self.logprob_start_len = -1 - if self.top_logprobs_num is None: - self.top_logprobs_num = 0 - if not self.token_ids_logprob: # covers both None and [] - self.token_ids_logprob = None + def _normalize_single_inputs(self): + """Normalize inputs for a single example.""" + if self.sampling_params is None: + self.sampling_params = {} + if self.rid is None: + self.rid = uuid.uuid4().hex + if self.return_logprob is None: + self.return_logprob = False + if self.logprob_start_len is None: + self.logprob_start_len = -1 + if self.top_logprobs_num is None: + self.top_logprobs_num = 0 + if not 
self.token_ids_logprob: # covers both None and [] + self.token_ids_logprob = None + + def _normalize_batch_inputs(self): + """Normalize inputs for a batch of examples, including parallel sampling expansion.""" + # Calculate expanded batch size + if self.parallel_sample_num == 1: + num = self.batch_size else: - if self.parallel_sample_num == 1: - num = self.batch_size + # Expand parallel_sample_num + num = self.batch_size * self.parallel_sample_num + + # Expand input based on type + self._expand_inputs(num) + self._normalize_lora_paths(num) + self._normalize_image_data(num) + self._normalize_audio_data(num) + self._normalize_sampling_params(num) + self._normalize_rid(num) + self._normalize_logprob_params(num) + self._normalize_custom_logit_processor(num) + + def _expand_inputs(self, num): + """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling.""" + if self.text is not None: + if not isinstance(self.text, list): + raise ValueError("Text should be a list for batch processing.") + self.text = self.text * self.parallel_sample_num + elif self.input_ids is not None: + if not isinstance(self.input_ids, list) or not isinstance( + self.input_ids[0], list + ): + raise ValueError( + "input_ids should be a list of lists for batch processing." 
+ ) + self.input_ids = self.input_ids * self.parallel_sample_num + elif self.input_embeds is not None: + if not isinstance(self.input_embeds, list): + raise ValueError("input_embeds should be a list for batch processing.") + self.input_embeds = self.input_embeds * self.parallel_sample_num + + def _normalize_lora_paths(self, num): + """Normalize LoRA paths for batch processing.""" + if self.lora_path is not None: + if isinstance(self.lora_path, str): + self.lora_path = [self.lora_path] * num + elif isinstance(self.lora_path, list): + self.lora_path = self.lora_path * self.parallel_sample_num else: + raise ValueError("lora_path should be a list or a string.") + + def _normalize_image_data(self, num): + """Normalize image data for batch processing.""" + if self.image_data is None: + self.image_data = [None] * num + elif not isinstance(self.image_data, list): + # Single image, convert to list of single-image lists + self.image_data = [[self.image_data]] * num + self.modalities = ["image"] * num + elif isinstance(self.image_data, list): + if len(self.image_data) != self.batch_size: + raise ValueError( + "The length of image_data should be equal to the batch size." 
+ ) + + self.modalities = [] + if len(self.image_data) > 0 and isinstance(self.image_data[0], list): + # Already a list of lists, keep as is + for i in range(len(self.image_data)): + if self.image_data[i] is None or self.image_data[i] == [None]: + self.modalities.append(None) + elif len(self.image_data[i]) == 1: + self.modalities.append("image") + elif len(self.image_data[i]) > 1: + self.modalities.append("multi-images") # Expand parallel_sample_num - num = self.batch_size * self.parallel_sample_num - - if not self.image_data: - self.image_data = [None] * num - elif not isinstance(self.image_data, list): - self.image_data = [self.image_data] * num - elif isinstance(self.image_data, list): - pass - - if self.audio_data is None: - self.audio_data = [None] * num - elif not isinstance(self.audio_data, list): - self.audio_data = [self.audio_data] * num - elif isinstance(self.audio_data, list): - pass - - if self.sampling_params is None: - self.sampling_params = [{}] * num - elif not isinstance(self.sampling_params, list): - self.sampling_params = [self.sampling_params] * num - - if self.rid is None: - self.rid = [uuid.uuid4().hex for _ in range(num)] + self.image_data = self.image_data * self.parallel_sample_num + self.modalities = self.modalities * self.parallel_sample_num else: - assert isinstance(self.rid, list), "The rid should be a list." 
+ # List of images for a batch, wrap each in a list + wrapped_images = [[img] for img in self.image_data] + # Expand for parallel sampling + self.image_data = wrapped_images * self.parallel_sample_num + self.modalities = ["image"] * num - if self.return_logprob is None: - self.return_logprob = [False] * num - elif not isinstance(self.return_logprob, list): - self.return_logprob = [self.return_logprob] * num + def _normalize_audio_data(self, num): + """Normalize audio data for batch processing.""" + if self.audio_data is None: + self.audio_data = [None] * num + elif not isinstance(self.audio_data, list): + self.audio_data = [self.audio_data] * num + elif isinstance(self.audio_data, list): + self.audio_data = self.audio_data * self.parallel_sample_num + + def _normalize_sampling_params(self, num): + """Normalize sampling parameters for batch processing.""" + if self.sampling_params is None: + self.sampling_params = [{}] * num + elif isinstance(self.sampling_params, dict): + self.sampling_params = [self.sampling_params] * num + else: # Already a list + self.sampling_params = self.sampling_params * self.parallel_sample_num + + def _normalize_rid(self, num): + """Normalize request IDs for batch processing.""" + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(num)] + elif not isinstance(self.rid, list): + raise ValueError("The rid should be a list for batch processing.") + + def _normalize_logprob_params(self, num): + """Normalize logprob-related parameters for batch processing.""" + + # Helper function to normalize a parameter + def normalize_param(param, default_value, param_name): + if param is None: + return [default_value] * num + elif not isinstance(param, list): + return [param] * num else: - assert self.parallel_sample_num == 1 + if self.parallel_sample_num > 1: + raise ValueError( + f"Cannot use list {param_name} with parallel_sample_num > 1" + ) + return param - if self.logprob_start_len is None: - self.logprob_start_len = [-1] * num - elif 
not isinstance(self.logprob_start_len, list): - self.logprob_start_len = [self.logprob_start_len] * num - else: - assert self.parallel_sample_num == 1 + # Normalize each logprob parameter + self.return_logprob = normalize_param( + self.return_logprob, False, "return_logprob" + ) + self.logprob_start_len = normalize_param( + self.logprob_start_len, -1, "logprob_start_len" + ) + self.top_logprobs_num = normalize_param( + self.top_logprobs_num, 0, "top_logprobs_num" + ) - if self.top_logprobs_num is None: - self.top_logprobs_num = [0] * num - elif not isinstance(self.top_logprobs_num, list): - self.top_logprobs_num = [self.top_logprobs_num] * num - else: - assert self.parallel_sample_num == 1 - - if not self.token_ids_logprob: # covers both None and [] - self.token_ids_logprob = [None] * num - elif not isinstance(self.token_ids_logprob, list): - self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)] - elif not isinstance(self.token_ids_logprob[0], list): - self.token_ids_logprob = [ - copy.deepcopy(self.token_ids_logprob) for _ in range(num) - ] - else: - assert self.parallel_sample_num == 1 - - if self.custom_logit_processor is None: - self.custom_logit_processor = [None] * num - elif not isinstance(self.custom_logit_processor, list): - self.custom_logit_processor = [self.custom_logit_processor] * num - else: - assert self.parallel_sample_num == 1 - - # Other checks - if self.session_params is not None: - assert isinstance(self.session_params, dict) or isinstance( - self.session_params[0], dict + # Handle token_ids_logprob specially due to its nested structure + if not self.token_ids_logprob: # covers both None and [] + self.token_ids_logprob = [None] * num + elif not isinstance(self.token_ids_logprob, list): + self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)] + elif not isinstance(self.token_ids_logprob[0], list): + self.token_ids_logprob = [ + copy.deepcopy(self.token_ids_logprob) for _ in range(num) + ] + elif 
self.parallel_sample_num > 1: + raise ValueError( + "Cannot use list token_ids_logprob with parallel_sample_num > 1" ) + def _normalize_custom_logit_processor(self, num): + """Normalize custom logit processor for batch processing.""" + if self.custom_logit_processor is None: + self.custom_logit_processor = [None] * num + elif not isinstance(self.custom_logit_processor, list): + self.custom_logit_processor = [self.custom_logit_processor] * num + elif self.parallel_sample_num > 1: + raise ValueError( + "Cannot use list custom_logit_processor with parallel_sample_num > 1" + ) + + def _validate_session_params(self): + """Validate that session parameters are properly formatted.""" + if self.session_params is not None: + if not isinstance(self.session_params, dict) and not isinstance( + self.session_params[0], dict + ): + raise ValueError("Session params must be a dict or a list of dicts.") + def regenerate_rid(self): + """Generate a new request ID and return it.""" self.rid = uuid.uuid4().hex return self.rid @@ -305,8 +440,15 @@ class TokenizedGenerateReqInput: class EmbeddingReqInput: # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[str], str]] = None - # The image input. It can be a file name, a url, or base64 encoded string. - image_data: Optional[Union[List[str], str]] = None + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. + image_data: Optional[ + Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]] + ] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None # The request id. 
diff --git a/test/srt/test_io_struct.py b/test/srt/test_io_struct.py new file mode 100644 index 000000000..452b3e3a4 --- /dev/null +++ b/test/srt/test_io_struct.py @@ -0,0 +1,527 @@ +import copy +import unittest + +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + CustomTestCase, +) + + +class TestGenerateReqInputNormalization(CustomTestCase): + """Test the normalization of GenerateReqInput for batch processing and different input formats.""" + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + def setUp(self): + # Common setup for all tests + self.base_req = GenerateReqInput( + text=["Hello", "World"], + sampling_params=[{}, {}], + rid=["id1", "id2"], + ) + + def test_single_image_to_list_of_lists(self): + """Test that a single image is converted to a list of single-image lists.""" + req = copy.deepcopy(self.base_req) + req.image_data = "single_image.jpg" # A single image (non-list) + + req.normalize_batch_and_arguments() + + # Should be converted to [[image], [image]] + self.assertEqual(len(req.image_data), 2) + self.assertEqual(len(req.image_data[0]), 1) + self.assertEqual(len(req.image_data[1]), 1) + self.assertEqual(req.image_data[0][0], "single_image.jpg") + self.assertEqual(req.image_data[1][0], "single_image.jpg") + + # Check modalities + self.assertEqual(req.modalities, ["image", "image"]) + + def test_list_of_images_to_list_of_lists(self): + """Test that a list of images is converted to a list of single-image lists.""" + req = copy.deepcopy(self.base_req) + req.image_data = ["image1.jpg", "image2.jpg"] # List of images + + req.normalize_batch_and_arguments() + + # Should be converted to [[image1], [image2]] + self.assertEqual(len(req.image_data), 2) + self.assertEqual(len(req.image_data[0]), 1) + self.assertEqual(len(req.image_data[1]), 1) + 
self.assertEqual(req.image_data[0][0], "image1.jpg") + self.assertEqual(req.image_data[1][0], "image2.jpg") + + # Check modalities + self.assertEqual(req.modalities, ["image", "image"]) + + def test_list_of_lists_with_different_modalities(self): + """Test handling of list of lists of images with different modalities.""" + req = copy.deepcopy(self.base_req) + req.image_data = [ + ["image1.jpg"], # Single image (image modality) + ["image2.jpg", "image3.jpg"], # Multiple images (multi-images modality) + ] + + req.normalize_batch_and_arguments() + + # Structure should remain the same + self.assertEqual(len(req.image_data), 2) + self.assertEqual(len(req.image_data[0]), 1) + self.assertEqual(len(req.image_data[1]), 2) + + # Check modalities + self.assertEqual(req.modalities, ["image", "multi-images"]) + + def test_list_of_lists_with_none_values(self): + """Test handling of list of lists with None values.""" + req = copy.deepcopy(self.base_req) + req.image_data = [ + [None], # None value + ["image.jpg"], # Single image + ] + + req.normalize_batch_and_arguments() + + # Structure should remain the same + self.assertEqual(len(req.image_data), 2) + self.assertEqual(len(req.image_data[0]), 1) + self.assertEqual(len(req.image_data[1]), 1) + + # Check modalities + self.assertEqual(req.modalities, [None, "image"]) + + def test_expanding_parallel_sample_correlation(self): + """Test that when expanding with parallel samples, prompts, images and modalities are properly correlated.""" + req = copy.deepcopy(self.base_req) + req.text = ["Prompt 1", "Prompt 2"] + req.image_data = [ + ["image1.jpg"], + ["image2.jpg", "image3.jpg"], + ] + req.sampling_params = {"n": 3} # All prompts get 3 samples + + # Define expected values before normalization + expected_text = req.text * 3 + expected_images = req.image_data * 3 + expected_modalities = ["image", "multi-images"] * 3 + + req.normalize_batch_and_arguments() + + # Should be expanded to 6 items (2 original * 3 parallel) + 
self.assertEqual(len(req.image_data), 6) + + # Check that images are properly expanded + self.assertEqual(req.image_data, expected_images) + + # Check modalities + self.assertEqual(req.modalities, expected_modalities) + + # Ensure that text items are properly duplicated too + self.assertEqual(req.text, expected_text) + + def test_list_of_lists_with_none_values_2(self): + """Test handling of list of lists with None values.""" + req = copy.deepcopy(self.base_req) + req.image_data = [ + [None], # None value + ["image.jpg"], # Single image + ] + + req.normalize_batch_and_arguments() + + # Structure should remain the same + self.assertEqual(len(req.image_data), 2) + self.assertEqual(len(req.image_data[0]), 1) + self.assertEqual(len(req.image_data[1]), 1) + + # Check modalities + self.assertEqual(req.modalities, [None, "image"]) + + def test_specific_parallel_n_per_sample(self): + """Test parallel expansion when different samples have different n values.""" + req = copy.deepcopy(self.base_req) + req.text = ["Prompt 1", "Prompt 2"] + req.image_data = [ + ["image1.jpg"], + ["image2.jpg", "image3.jpg"], + ] + req.sampling_params = [ + {"n": 2}, + {"n": 2}, + ] # First prompt gets 2 samples, second prompt gets 2 samples + + expected_images = req.image_data * 2 + expected_modalities = ["image", "multi-images"] * 2 + expected_text = req.text * 2 + + req.normalize_batch_and_arguments() + + # Should be expanded to 4 items (2 original * 2 parallel) + self.assertEqual(len(req.image_data), 4) + + # Check that the first 2 are copies for the first prompt + self.assertEqual(req.image_data, expected_images) + + # Check modalities + self.assertEqual(req.modalities, expected_modalities) + + # Check text expansion + self.assertEqual(req.text, expected_text) + + def test_mixed_none_and_images_with_parallel_samples(self): + """Test that when some batch items have images and others None, parallel expansion works correctly.""" + req = copy.deepcopy(self.base_req) + req.text = ["Prompt 1", 
"Prompt 2", "Prompt 3"] + req.image_data = [ + ["image1.jpg"], + None, + ["image3_1.jpg", "image3_2.jpg"], + ] + req.sampling_params = {"n": 2} # All prompts get 2 samples + + expected_images = req.image_data * 2 + expected_modalities = ["image", None, "multi-images"] * 2 + expected_text = req.text * 2 + + req.normalize_batch_and_arguments() + + # Should be expanded to 6 items (3 original * 2 parallel) + self.assertEqual(len(req.image_data), 6) + + # Check image data + self.assertEqual(req.image_data, expected_images) + + # Check modalities + self.assertEqual(req.modalities, expected_modalities) + + # Check text expansion + self.assertEqual(req.text, expected_text) + + def test_correlation_with_sampling_params(self): + """Test that sampling parameters are correctly correlated with prompts during expansion.""" + req = copy.deepcopy(self.base_req) + req.text = ["Prompt 1", "Prompt 2"] + req.image_data = [ + ["image1.jpg"], + ["image2.jpg"], + ] + req.sampling_params = [ + {"temperature": 0.7, "n": 2}, + {"temperature": 0.9, "n": 2}, + ] + + req.normalize_batch_and_arguments() + + # Check sampling params expansion + self.assertEqual(len(req.sampling_params), 4) + self.assertEqual(req.sampling_params[0]["temperature"], 0.7) + self.assertEqual(req.sampling_params[1]["temperature"], 0.9) + self.assertEqual(req.sampling_params[2]["temperature"], 0.7) + self.assertEqual(req.sampling_params[3]["temperature"], 0.9) + + # Should be expanded to 4 items (2 original * 2 parallel) + self.assertEqual(len(req.image_data), 4) + + # Check correlation with images + self.assertEqual(req.image_data[0], ["image1.jpg"]) + self.assertEqual(req.image_data[1], ["image2.jpg"]) + self.assertEqual(req.image_data[2], ["image1.jpg"]) + self.assertEqual(req.image_data[3], ["image2.jpg"]) + + def test_single_example_with_image(self): + """Test handling of single example with image.""" + req = GenerateReqInput( + text="Hello", + image_data="single_image.jpg", + ) + + 
req.normalize_batch_and_arguments() + + # For single examples, image_data doesn't get processed into lists + self.assertEqual(req.image_data, "single_image.jpg") + self.assertIsNone(req.modalities) # Modalities isn't set for single examples + + def test_single_to_batch_with_parallel_sampling(self): + """Test single example converted to batch with parallel sampling.""" + req = GenerateReqInput( + text="Hello", + image_data="single_image.jpg", + sampling_params={"n": 3}, # parallel_sample_num = 3 + ) + + # Define expected values before normalization + expected_text = ["Hello"] * 3 + + req.normalize_batch_and_arguments() + + # Should be converted to batch with text=["Hello"] + self.assertEqual(req.text, expected_text) + + # Image should be automatically wrapped to list of lists with length 1*3=3 + self.assertEqual(len(req.image_data), 3) + self.assertEqual(req.image_data[0][0], "single_image.jpg") + self.assertEqual(req.image_data[1][0], "single_image.jpg") + self.assertEqual(req.image_data[2][0], "single_image.jpg") + + # Modalities should be set for all 3 examples + self.assertEqual(req.modalities, ["image", "image", "image"]) + + def test_audio_data_handling(self): + """Test handling of audio_data.""" + req = copy.deepcopy(self.base_req) + req.audio_data = "audio.mp3" # Single audio + + req.normalize_batch_and_arguments() + + # Should be converted to ["audio.mp3", "audio.mp3"] + self.assertEqual(len(req.audio_data), 2) + self.assertEqual(req.audio_data[0], "audio.mp3") + self.assertEqual(req.audio_data[1], "audio.mp3") + + # Test with list + req = copy.deepcopy(self.base_req) + req.audio_data = ["audio1.mp3", "audio2.mp3"] + + req.normalize_batch_and_arguments() + + # Should remain the same + self.assertEqual(len(req.audio_data), 2) + self.assertEqual(req.audio_data[0], "audio1.mp3") + self.assertEqual(req.audio_data[1], "audio2.mp3") + + def test_input_ids_normalization(self): + """Test normalization of input_ids instead of text.""" + # Test single input_ids + req 
= GenerateReqInput(input_ids=[1, 2, 3]) + req.normalize_batch_and_arguments() + self.assertTrue(req.is_single) + self.assertEqual(req.batch_size, 1) + + # Test batch input_ids + req = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5, 6]]) + req.normalize_batch_and_arguments() + self.assertFalse(req.is_single) + self.assertEqual(req.batch_size, 2) + + # Test with parallel sampling + req = GenerateReqInput( + input_ids=[[1, 2, 3], [4, 5, 6]], sampling_params={"n": 2} + ) + req.normalize_batch_and_arguments() + self.assertEqual(len(req.input_ids), 4) # 2 original * 2 parallel + + def test_input_embeds_normalization(self): + """Test normalization of input_embeds.""" + # Test single input_embeds + req = GenerateReqInput(input_embeds=[[0.1, 0.2], [0.3, 0.4]]) + req.normalize_batch_and_arguments() + self.assertTrue(req.is_single) + self.assertEqual(req.batch_size, 1) + + # Test batch input_embeds + req = GenerateReqInput(input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]]) + req.normalize_batch_and_arguments() + self.assertFalse(req.is_single) + self.assertEqual(req.batch_size, 2) + + def test_lora_path_normalization(self): + """Test normalization of lora_path.""" + # Test single lora_path with batch input + req = GenerateReqInput(text=["Hello", "World"], lora_path="path/to/lora") + + # Define expected lora_paths before normalization + expected_lora_paths = ["path/to/lora", "path/to/lora"] + + req.normalize_batch_and_arguments() + self.assertEqual(req.lora_path, expected_lora_paths) + + # Test list of lora_paths + req = GenerateReqInput(text=["Hello", "World"], lora_path=["path1", "path2"]) + + # Define expected lora_paths before normalization + expected_lora_paths = ["path1", "path2"] + + req.normalize_batch_and_arguments() + self.assertEqual(req.lora_path, expected_lora_paths) + + # Test with parallel sampling + req = GenerateReqInput( + text=["Hello", "World"], + lora_path=["path1", "path2"], + sampling_params={"n": 2}, + ) + + # Define expected lora_paths before normalization + 
expected_lora_paths = ["path1", "path2"] * 2 + + req.normalize_batch_and_arguments() + self.assertEqual(req.lora_path, expected_lora_paths) + + def test_logprob_parameters_normalization(self): + """Test normalization of logprob-related parameters.""" + # Test single example + req = GenerateReqInput( + text="Hello", + return_logprob=True, + logprob_start_len=10, + top_logprobs_num=5, + token_ids_logprob=[7, 8, 9], + ) + req.normalize_batch_and_arguments() + self.assertEqual(req.return_logprob, True) + self.assertEqual(req.logprob_start_len, 10) + self.assertEqual(req.top_logprobs_num, 5) + self.assertEqual(req.token_ids_logprob, [7, 8, 9]) + + # Test batch with scalar values + req = GenerateReqInput( + text=["Hello", "World"], + return_logprob=True, + logprob_start_len=10, + top_logprobs_num=5, + token_ids_logprob=[7, 8, 9], + ) + req.normalize_batch_and_arguments() + self.assertEqual(req.return_logprob, [True, True]) + self.assertEqual(req.logprob_start_len, [10, 10]) + self.assertEqual(req.top_logprobs_num, [5, 5]) + self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [7, 8, 9]]) + + # Test batch with list values + req = GenerateReqInput( + text=["Hello", "World"], + return_logprob=[True, False], + logprob_start_len=[10, 5], + top_logprobs_num=[5, 3], + token_ids_logprob=[[7, 8, 9], [4, 5, 6]], + ) + req.normalize_batch_and_arguments() + self.assertEqual(req.return_logprob, [True, False]) + self.assertEqual(req.logprob_start_len, [10, 5]) + self.assertEqual(req.top_logprobs_num, [5, 3]) + self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]]) + + def test_custom_logit_processor_normalization(self): + """Test normalization of custom_logit_processor.""" + # Test single processor + req = GenerateReqInput( + text=["Hello", "World"], custom_logit_processor="serialized_processor" + ) + req.normalize_batch_and_arguments() + self.assertEqual( + req.custom_logit_processor, ["serialized_processor", "serialized_processor"] + ) + + # Test list of processors + req = 
GenerateReqInput( + text=["Hello", "World"], custom_logit_processor=["processor1", "processor2"] + ) + req.normalize_batch_and_arguments() + self.assertEqual(req.custom_logit_processor, ["processor1", "processor2"]) + + def test_session_params_handling(self): + """Test handling of session_params.""" + # Test with dict + req = GenerateReqInput( + text=["Hello", "World"], session_params={"id": "session1", "offset": 10} + ) + req.normalize_batch_and_arguments() + self.assertEqual(req.session_params, {"id": "session1", "offset": 10}) + + # Test with list of dicts + req = GenerateReqInput( + text=["Hello", "World"], + session_params=[{"id": "session1"}, {"id": "session2"}], + ) + req.normalize_batch_and_arguments() + self.assertEqual(req.session_params, [{"id": "session1"}, {"id": "session2"}]) + + def test_getitem_method(self): + """Test the __getitem__ method.""" + req = GenerateReqInput( + text=["Hello", "World"], + image_data=[["img1.jpg"], ["img2.jpg"]], + audio_data=["audio1.mp3", "audio2.mp3"], + sampling_params=[{"temp": 0.7}, {"temp": 0.8}], + rid=["id1", "id2"], + return_logprob=[True, False], + logprob_start_len=[10, 5], + top_logprobs_num=[5, 3], + token_ids_logprob=[[7, 8, 9], [4, 5, 6]], + stream=True, + log_metrics=True, + modalities=["image", "image"], + lora_path=["path1", "path2"], + custom_logit_processor=["processor1", "processor2"], + return_hidden_states=True, + ) + req.normalize_batch_and_arguments() + + # Get the first item + item0 = req[0] + self.assertEqual(item0.text, "Hello") + self.assertEqual(item0.image_data, ["img1.jpg"]) + self.assertEqual(item0.audio_data, "audio1.mp3") + self.assertEqual(item0.sampling_params, {"temp": 0.7}) + self.assertEqual(item0.rid, "id1") + self.assertEqual(item0.return_logprob, True) + self.assertEqual(item0.logprob_start_len, 10) + self.assertEqual(item0.top_logprobs_num, 5) + self.assertEqual(item0.token_ids_logprob, [7, 8, 9]) + self.assertEqual(item0.stream, True) + self.assertEqual(item0.log_metrics, True) 
+ self.assertEqual(item0.modalities, "image") + self.assertEqual(item0.lora_path, "path1") + self.assertEqual(item0.custom_logit_processor, "processor1") + self.assertEqual(item0.return_hidden_states, True) + + def test_regenerate_rid(self): + """Test the regenerate_rid method.""" + req = GenerateReqInput(text="Hello") + req.normalize_batch_and_arguments() + + original_rid = req.rid + new_rid = req.regenerate_rid() + + self.assertNotEqual(original_rid, new_rid) + self.assertEqual(req.rid, new_rid) + + def test_error_cases(self): + """Test various error cases.""" + # Test when neither text, input_ids, nor input_embeds is provided + with self.assertRaises(ValueError): + req = GenerateReqInput() + req.normalize_batch_and_arguments() + + # Test when all of text, input_ids, and input_embeds are provided + with self.assertRaises(ValueError): + req = GenerateReqInput( + text="Hello", input_ids=[1, 2, 3], input_embeds=[[0.1, 0.2]] + ) + req.normalize_batch_and_arguments() + + def test_multiple_input_formats(self): + """Test different combinations of input formats.""" + # Test with text only + req = GenerateReqInput(text="Hello") + req.normalize_batch_and_arguments() + self.assertTrue(req.is_single) + + # Test with input_ids only + req = GenerateReqInput(input_ids=[1, 2, 3]) + req.normalize_batch_and_arguments() + self.assertTrue(req.is_single) + + # Test with input_embeds only + req = GenerateReqInput(input_embeds=[[0.1, 0.2]]) + req.normalize_batch_and_arguments() + self.assertTrue(req.is_single) + + +if __name__ == "__main__": + unittest.main()