fix: Handle input_embeds in GenerateReqInput when n>1 (#7830)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-07-08 14:00:42 -07:00
committed by GitHub
parent 43e20c0647
commit 136c6e0431
3 changed files with 75 additions and 1 deletions

View File

@@ -200,6 +200,8 @@ class GenerateReqInput:
self.text = [self.text] self.text = [self.text]
if self.input_ids is not None: if self.input_ids is not None:
self.input_ids = [self.input_ids] self.input_ids = [self.input_ids]
if self.input_embeds is not None:
self.input_embeds = [self.input_embeds]
def _normalize_single_inputs(self): def _normalize_single_inputs(self):
"""Normalize inputs for a single example.""" """Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ class GenerateReqInput:
new_rids = [f"{self.rid}_{i}" for i in range(num)] new_rids = [f"{self.rid}_{i}" for i in range(num)]
self.rid = new_rids self.rid = new_rids
elif isinstance(self.rid, list): elif isinstance(self.rid, list):
if len(self.rid) != num: # Note: the length of rid shall be the same as the batch_size,
# as the rid would be expanded for parallel sampling in tokenizer_manager
if len(self.rid) != self.batch_size:
raise ValueError( raise ValueError(
"The specified rids length mismatch with the batch_size for batch processing." "The specified rids length mismatch with the batch_size for batch processing."
) )
@@ -400,6 +404,9 @@ class GenerateReqInput:
return GenerateReqInput( return GenerateReqInput(
text=self.text[i] if self.text is not None else None, text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None, input_ids=self.input_ids[i] if self.input_ids is not None else None,
input_embeds=(
self.input_embeds[i] if self.input_embeds is not None else None
),
image_data=self.image_data[i], image_data=self.image_data[i],
audio_data=self.audio_data[i], audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i], sampling_params=self.sampling_params[i],

View File

@@ -67,6 +67,7 @@ suites = {
TestFile("test_hidden_states.py", 55), TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8), TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38), TestFile("test_input_embeddings.py", 38),
TestFile("test_io_struct.py", 8),
TestFile("test_jinja_template_utils.py", 1), TestFile("test_jinja_template_utils.py", 1),
TestFile("test_metrics.py", 32), TestFile("test_metrics.py", 32),
TestFile("test_mla.py", 167), TestFile("test_mla.py", 167),

View File

@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
"""Test that when some batch items have images and others None, parallel expansion works correctly.""" """Test that when some batch items have images and others None, parallel expansion works correctly."""
req = copy.deepcopy(self.base_req) req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2", "Prompt 3"] req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
req.rid = ["id1", "id2", "id3"]
req.image_data = [ req.image_data = [
["image1.jpg"], ["image1.jpg"],
None, None,
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
self.assertFalse(req.is_single) self.assertFalse(req.is_single)
self.assertEqual(req.batch_size, 2) self.assertEqual(req.batch_size, 2)
def test_input_embeds_with_parallel_sampling(self):
    """Verify input_embeds normalization when parallel sampling (n > 1) is requested.

    Three scenarios are exercised in order:
    a non-batched input with n=2, a two-item batch with n=3, and a batch
    whose items ask for different n values (expected to be rejected).
    """
    # Scenario 1: a non-batched input_embeds sampled twice.
    single_req = GenerateReqInput(
        input_embeds=[[0.1, 0.2]],
        sampling_params={"n": 2},
    )
    single_req.normalize_batch_and_arguments()

    # The request must now be a batch holding one entry per sample ...
    self.assertFalse(single_req.is_single)
    self.assertEqual(len(single_req.input_embeds), 2)
    # ... and every entry must equal the original input.
    for position in (0, 1):
        self.assertEqual(single_req.input_embeds[position], [[0.1, 0.2]])

    # Scenario 2: a two-item batch sampled three times each.
    batch_req = GenerateReqInput(
        input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
        sampling_params={"n": 3},
    )
    batch_req.normalize_batch_and_arguments()

    self.assertFalse(batch_req.is_single)
    self.assertEqual(len(batch_req.input_embeds), 6)
    # The whole batch is expected to be repeated back-to-back three times.
    repeated_batch = [
        item for _ in range(3) for item in ([[0.1, 0.2]], [[0.3, 0.4]])
    ]
    self.assertEqual(batch_req.input_embeds, repeated_batch)

    # Scenario 3: per-item n values that disagree must raise ValueError.
    mismatched_req = GenerateReqInput(
        input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
        sampling_params=[{"n": 2}, {"n": 3}],
    )
    self.assertRaises(ValueError, mismatched_req.normalize_batch_and_arguments)
def test_input_embeds_single_to_batch_conversion(self):
    """Verify a non-batched input_embeds becomes a batch of n copies when n > 1.

    Runs the same check for n=2 and then for n=5: after normalization the
    request is no longer single, holds n entries, and each entry equals the
    original input.
    """
    for sample_count in (2, 5):
        request = GenerateReqInput(
            input_embeds=[[0.1, 0.2, 0.3]],  # one non-batched input
            sampling_params={"n": sample_count},
        )
        request.normalize_batch_and_arguments()

        # Converted to a batch and expanded to one entry per sample.
        self.assertFalse(request.is_single)
        self.assertEqual(len(request.input_embeds), sample_count)

        # Every expanded entry must match the original input.
        for position in range(sample_count):
            self.assertEqual(request.input_embeds[position], [[0.1, 0.2, 0.3]])
def test_lora_path_normalization(self): def test_lora_path_normalization(self):
"""Test normalization of lora_path.""" """Test normalization of lora_path."""
# Test single lora_path with batch input # Test single lora_path with batch input