From 136c6e0431c2067c3a2a98ad2c77fc89a9cb98e7 Mon Sep 17 00:00:00 2001
From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Date: Tue, 8 Jul 2025 14:00:42 -0700
Subject: [PATCH] fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)

Signed-off-by: Xinyuan Tong
---
 python/sglang/srt/managers/io_struct.py |  9 +++-
 test/srt/run_suite.py                   |  1 +
 test/srt/test_io_struct.py              | 66 +++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index f3a59c03c..d0cc3e5d6 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -200,6 +200,8 @@ class GenerateReqInput:
                 self.text = [self.text]
             if self.input_ids is not None:
                 self.input_ids = [self.input_ids]
+            if self.input_embeds is not None:
+                self.input_embeds = [self.input_embeds]
 
     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ class GenerateReqInput:
             new_rids = [f"{self.rid}_{i}" for i in range(num)]
             self.rid = new_rids
         elif isinstance(self.rid, list):
-            if len(self.rid) != num:
+            # Note: the length of rid shall be the same as the batch_size,
+            # as the rid would be expanded for parallel sampling in tokenizer_manager
+            if len(self.rid) != self.batch_size:
                 raise ValueError(
                     "The specified rids length mismatch with the batch_size for batch processing."
                 )
@@ -400,6 +404,9 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index b2c8c9252..a3c8e1a8d 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -67,6 +67,7 @@ suites = {
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 8),
         TestFile("test_input_embeddings.py", 38),
+        TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
         TestFile("test_metrics.py", 32),
         TestFile("test_mla.py", 167),
diff --git a/test/srt/test_io_struct.py b/test/srt/test_io_struct.py
index b8fdec8ec..a8efebb23 100644
--- a/test/srt/test_io_struct.py
+++ b/test/srt/test_io_struct.py
@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
         """Test that when some batch items have images and others None, parallel expansion works correctly."""
         req = copy.deepcopy(self.base_req)
         req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
+        req.rid = ["id1", "id2", "id3"]
         req.image_data = [
             ["image1.jpg"],
             None,
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
         self.assertFalse(req.is_single)
         self.assertEqual(req.batch_size, 2)
 
+    def test_input_embeds_with_parallel_sampling(self):
+        """Test input_embeds normalization with parallel sampling (n > 1)."""
+        # Test single input_embeds with parallel sampling
+        req = GenerateReqInput(
+            input_embeds=[[0.1, 0.2]],  # single embedding vector
+            sampling_params={"n": 2},
+        )
+        req.normalize_batch_and_arguments()
+
+        # Should be converted from single to batch and then expanded
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 2)
+        # Both should be the same input_embeds
+        self.assertEqual(req.input_embeds[0], [[0.1, 0.2]])
+        self.assertEqual(req.input_embeds[1], [[0.1, 0.2]])
+
+        # Test batch input_embeds with parallel sampling
+        req = GenerateReqInput(
+            input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]], sampling_params={"n": 3}
+        )
+        req.normalize_batch_and_arguments()
+
+        # Should be expanded
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 6)
+
+        # Check that the expansion is correct
+        expected_embeds = [[[0.1, 0.2]], [[0.3, 0.4]]] * 3
+        self.assertEqual(req.input_embeds, expected_embeds)
+
+        # Test with different n values per sample (should raise error)
+        req = GenerateReqInput(
+            input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
+            sampling_params=[{"n": 2}, {"n": 3}],
+        )
+        with self.assertRaises(ValueError):
+            req.normalize_batch_and_arguments()
+
+    def test_input_embeds_single_to_batch_conversion(self):
+        """Test that single input_embeds are properly converted to batch when using parallel sampling."""
+        # Test the specific case that was fixed: single input_embeds with n > 1
+        req = GenerateReqInput(
+            input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 2}  # Single embedding
+        )
+        req.normalize_batch_and_arguments()
+
+        # Should convert single to batch and then expand
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 2)
+
+        # Both should be the same single embedding
+        self.assertEqual(req.input_embeds[0], [[0.1, 0.2, 0.3]])
+        self.assertEqual(req.input_embeds[1], [[0.1, 0.2, 0.3]])
+
+        # Test with higher n value
+        req = GenerateReqInput(input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 5})
+        req.normalize_batch_and_arguments()
+
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 5)
+
+        # All should be the same
+        for i in range(5):
+            self.assertEqual(req.input_embeds[i], [[0.1, 0.2, 0.3]])
+
     def test_lora_path_normalization(self):
         """Test normalization of lora_path."""
         # Test single lora_path with batch input