fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -200,6 +200,8 @@ class GenerateReqInput:
|
||||
self.text = [self.text]
|
||||
if self.input_ids is not None:
|
||||
self.input_ids = [self.input_ids]
|
||||
if self.input_embeds is not None:
|
||||
self.input_embeds = [self.input_embeds]
|
||||
|
||||
def _normalize_single_inputs(self):
|
||||
"""Normalize inputs for a single example."""
|
||||
@@ -324,7 +326,9 @@ class GenerateReqInput:
|
||||
new_rids = [f"{self.rid}_{i}" for i in range(num)]
|
||||
self.rid = new_rids
|
||||
elif isinstance(self.rid, list):
|
||||
if len(self.rid) != num:
|
||||
# Note: the length of rid shall be the same as the batch_size,
|
||||
# as the rid would be expanded for parallel sampling in tokenizer_manager
|
||||
if len(self.rid) != self.batch_size:
|
||||
raise ValueError(
|
||||
"The specified rids length mismatch with the batch_size for batch processing."
|
||||
)
|
||||
@@ -400,6 +404,9 @@ class GenerateReqInput:
|
||||
return GenerateReqInput(
|
||||
text=self.text[i] if self.text is not None else None,
|
||||
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
||||
input_embeds=(
|
||||
self.input_embeds[i] if self.input_embeds is not None else None
|
||||
),
|
||||
image_data=self.image_data[i],
|
||||
audio_data=self.audio_data[i],
|
||||
sampling_params=self.sampling_params[i],
|
||||
|
||||
@@ -67,6 +67,7 @@ suites = {
|
||||
TestFile("test_hidden_states.py", 55),
|
||||
TestFile("test_int8_kernel.py", 8),
|
||||
TestFile("test_input_embeddings.py", 38),
|
||||
TestFile("test_io_struct.py", 8),
|
||||
TestFile("test_jinja_template_utils.py", 1),
|
||||
TestFile("test_metrics.py", 32),
|
||||
TestFile("test_mla.py", 167),
|
||||
|
||||
@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
|
||||
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
|
||||
req = copy.deepcopy(self.base_req)
|
||||
req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
||||
req.rid = ["id1", "id2", "id3"]
|
||||
req.image_data = [
|
||||
["image1.jpg"],
|
||||
None,
|
||||
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
|
||||
self.assertFalse(req.is_single)
|
||||
self.assertEqual(req.batch_size, 2)
|
||||
|
||||
def test_input_embeds_with_parallel_sampling(self):
|
||||
"""Test input_embeds normalization with parallel sampling (n > 1)."""
|
||||
# Test single input_embeds with parallel sampling
|
||||
req = GenerateReqInput(
|
||||
input_embeds=[[0.1, 0.2]], # single embedding vector
|
||||
sampling_params={"n": 2},
|
||||
)
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be converted from single to batch and then expanded
|
||||
self.assertFalse(req.is_single)
|
||||
self.assertEqual(len(req.input_embeds), 2)
|
||||
# Both should be the same input_embeds
|
||||
self.assertEqual(req.input_embeds[0], [[0.1, 0.2]])
|
||||
self.assertEqual(req.input_embeds[1], [[0.1, 0.2]])
|
||||
|
||||
# Test batch input_embeds with parallel sampling
|
||||
req = GenerateReqInput(
|
||||
input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]], sampling_params={"n": 3}
|
||||
)
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should be expanded
|
||||
self.assertFalse(req.is_single)
|
||||
self.assertEqual(len(req.input_embeds), 6)
|
||||
|
||||
# Check that the expansion is correct
|
||||
expected_embeds = [[[0.1, 0.2]], [[0.3, 0.4]]] * 3
|
||||
self.assertEqual(req.input_embeds, expected_embeds)
|
||||
|
||||
# Test with different n values per sample (should raise error)
|
||||
req = GenerateReqInput(
|
||||
input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
|
||||
sampling_params=[{"n": 2}, {"n": 3}],
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
def test_input_embeds_single_to_batch_conversion(self):
|
||||
"""Test that single input_embeds are properly converted to batch when using parallel sampling."""
|
||||
# Test the specific case that was fixed: single input_embeds with n > 1
|
||||
req = GenerateReqInput(
|
||||
input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 2} # Single embedding
|
||||
)
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
# Should convert single to batch and then expand
|
||||
self.assertFalse(req.is_single)
|
||||
self.assertEqual(len(req.input_embeds), 2)
|
||||
|
||||
# Both should be the same single embedding
|
||||
self.assertEqual(req.input_embeds[0], [[0.1, 0.2, 0.3]])
|
||||
self.assertEqual(req.input_embeds[1], [[0.1, 0.2, 0.3]])
|
||||
|
||||
# Test with higher n value
|
||||
req = GenerateReqInput(input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 5})
|
||||
req.normalize_batch_and_arguments()
|
||||
|
||||
self.assertFalse(req.is_single)
|
||||
self.assertEqual(len(req.input_embeds), 5)
|
||||
|
||||
# All should be the same
|
||||
for i in range(5):
|
||||
self.assertEqual(req.input_embeds[i], [[0.1, 0.2, 0.3]])
|
||||
|
||||
def test_lora_path_normalization(self):
|
||||
"""Test normalization of lora_path."""
|
||||
# Test single lora_path with batch input
|
||||
|
||||
Reference in New Issue
Block a user