fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-07-08 14:00:42 -07:00
committed by GitHub
parent 43e20c0647
commit 136c6e0431
3 changed files with 75 additions and 1 deletions

View File

@@ -200,6 +200,8 @@ class GenerateReqInput:
self.text = [self.text]
if self.input_ids is not None:
self.input_ids = [self.input_ids]
if self.input_embeds is not None:
self.input_embeds = [self.input_embeds]

def _normalize_single_inputs(self):
"""Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ class GenerateReqInput:
new_rids = [f"{self.rid}_{i}" for i in range(num)]
self.rid = new_rids
elif isinstance(self.rid, list):
if len(self.rid) != num:
# Note: the length of rid shall be the same as the batch_size,
# as the rid would be expanded for parallel sampling in tokenizer_manager
if len(self.rid) != self.batch_size:
raise ValueError(
"The specified rids length mismatch with the batch_size for batch processing."
)
@@ -400,6 +404,9 @@ class GenerateReqInput:
return GenerateReqInput(
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
input_embeds=(
self.input_embeds[i] if self.input_embeds is not None else None
),
image_data=self.image_data[i],
audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i],